Automated g4 rollback of changelist 196683444
authorPeter Hawkins <phawkins@google.com>
Tue, 15 May 2018 17:33:48 +0000 (10:33 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 15 May 2018 17:36:11 +0000 (10:36 -0700)
PiperOrigin-RevId: 196691101

20 files changed:
tensorflow/compiler/aot/tests/tfcompile_test.cc
tensorflow/compiler/jit/kernels/xla_launch_op.cc
tensorflow/compiler/jit/xla_cpu_device.cc
tensorflow/compiler/jit/xla_device.cc
tensorflow/compiler/jit/xla_device.h
tensorflow/compiler/jit/xla_device_context.cc
tensorflow/compiler/jit/xla_device_context.h
tensorflow/compiler/jit/xla_gpu_device.cc
tensorflow/compiler/jit/xla_launch_util.cc
tensorflow/compiler/tests/BUILD
tensorflow/compiler/tests/xla_device_gpu_test.py [deleted file]
tensorflow/compiler/tests/xla_device_test.py
tensorflow/compiler/tf2xla/BUILD
tensorflow/compiler/tf2xla/kernels/retval_op.cc
tensorflow/compiler/tf2xla/xla_compiler.cc
tensorflow/compiler/tf2xla/xla_compiler.h
tensorflow/compiler/tf2xla/xla_compiler_test.cc
tensorflow/compiler/tf2xla/xla_context.cc
tensorflow/compiler/tf2xla/xla_context.h
tensorflow/compiler/tf2xla/xla_op_kernel.cc

index fee4628..868d752 100644 (file)
@@ -551,16 +551,14 @@ TEST(TFCompileTest, HloProfiling) {
   auto header = HasSubstr("Execution profile for");
   auto total_cycles_profile_line = HasSubstr("[total]");
   auto dot_profile_line = HasSubstr(
-      "%dot.0.4 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
+      "%dot.0.2 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
       "%arg1.0.1)");
   auto add_profile_line = HasSubstr(
-      "%add.0.6 = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
+      "%add.0.5 = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
       "%arg1.0.1)");
   auto tuple_profile_line = HasSubstr(
       "%tuple.0.8 = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} "
-      "%dot.0.4, f32[2,2]{1,0} %add.0.6)");
-  auto arg0_profile_line = HasSubstr("%arg0.0.0 = f32[2,2]{1,0} parameter(0)");
-  auto arg1_profile_line = HasSubstr("%arg1.0.1 = f32[2,2]{1,0} parameter(1)");
+      "%dot.0.2, f32[2,2]{1,0} %add.0.5)");
 
   EXPECT_THAT(hlo_profile_lines,
               IsSupersetOf({header, total_cycles_profile_line, dot_profile_line,
index 9d85634..86a9fd3 100644 (file)
@@ -112,7 +112,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
   // this is more obviously correct.)
   core::ScopedUnref cache_ref(cache);
 
-  const XlaDevice::Metadata* metadata = nullptr;
+  const XlaDevice::Metadata* metadata;
   Status s = XlaDevice::GetMetadata(ctx, &metadata);
   bool allocate_xla_tensors = s.ok();
 
@@ -153,9 +153,9 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
   options.graph_def_version = ctx->function_library()->graph_def_version();
   options.allow_cpu_custom_calls = (platform_id_ == se::host::kHostPlatformId);
   options.device_allocator = xla_allocator;
-  if (metadata) {
-    options.shape_representation_fn = metadata->shape_representation_fn();
-  }
+  // TODO(b/77671268): We don't set variable_representation_shape_fn here. This
+  // is restricted to Variables, but we need something like this to apply to
+  // normal Tensors too.
 
   const XlaCompiler::CompilationResult* kernel;
   xla::LocalExecutable* executable;
@@ -164,11 +164,9 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
   for (int i : constants_) {
     constant_args.insert({i, ctx->input(i)});
   }
-  XlaCompiler::CompileOptions compile_options;
-  compile_options.is_entry_computation = true;
-  OP_REQUIRES_OK(
-      ctx, cache->Compile(options, function_, constant_args, variables, ctx,
-                          &kernel, &executable, &compile_options));
+  OP_REQUIRES_OK(ctx, cache->Compile(options, function_, constant_args,
+                                     variables, ctx, &kernel, &executable,
+                                     /*compile_options=*/nullptr));
 
   VLOG(1) << "Executing XLA Computation...";
 
index ea9e036..bc07dbd 100644 (file)
@@ -50,11 +50,10 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& options,
   (void)registrations;
 
   std::unique_ptr<XlaDevice> device;
-  TF_RETURN_IF_ERROR(
-      XlaDevice::Create("Host", DEVICE_XLA_CPU, 0, DEVICE_CPU_XLA_JIT, options,
-                        name_prefix, registration,
-                        /*transfer_as_literal=*/false,
-                        /*shape_representation_fn=*/{}, &device));
+  TF_RETURN_IF_ERROR(XlaDevice::Create("Host", DEVICE_XLA_CPU, 0,
+                                       DEVICE_CPU_XLA_JIT, options, name_prefix,
+                                       registration,
+                                       /*transfer_as_literal=*/false, &device));
   devices->push_back(device.release());
   return Status::OK();
 }
index 9ee5b04..70263b1 100644 (file)
@@ -110,9 +110,7 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
     const string& jit_device_name, const SessionOptions& options,
     const string& name_prefix,
     const XlaOpRegistry::DeviceRegistration& registration,
-    bool transfer_as_literal,
-    const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
-    std::unique_ptr<XlaDevice>* device) {
+    bool transfer_as_literal, std::unique_ptr<XlaDevice>* device) {
   VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":"
           << device_ordinal;
 
@@ -131,19 +129,17 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
       DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(),
       strings::StrCat("device: ", device_name, " device"));
 
-  device->reset(new XlaDevice(
-      options, attrs, device_ordinal, DeviceType(jit_device_name),
-      platform.ValueOrDie(), transfer_as_literal, shape_representation_fn));
+  device->reset(new XlaDevice(options, attrs, device_ordinal,
+                              DeviceType(jit_device_name),
+                              platform.ValueOrDie(), transfer_as_literal));
   return Status::OK();
 }
 
-XlaDevice::Metadata::Metadata(
-    int device_ordinal, se::Platform* platform, const DeviceType& device_type,
-    XlaCompiler::ShapeRepresentationFn shape_representation_fn)
+XlaDevice::Metadata::Metadata(int device_ordinal, se::Platform* platform,
+                              const DeviceType& device_type)
     : device_ordinal_(device_ordinal),
       device_type_(device_type),
-      platform_(platform),
-      shape_representation_fn_(std::move(shape_representation_fn)) {}
+      platform_(platform) {}
 
 int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; }
 
@@ -174,20 +170,17 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const {
   return Status::OK();
 }
 
-XlaDevice::XlaDevice(
-    const SessionOptions& options, const DeviceAttributes& attrs,
-    int device_ordinal, const DeviceType& jit_device_name,
-    se::Platform* platform, bool transfer_as_literal,
-    const XlaCompiler::ShapeRepresentationFn& shape_representation_fn)
+XlaDevice::XlaDevice(const SessionOptions& options,
+                     const DeviceAttributes& attrs, int device_ordinal,
+                     const DeviceType& jit_device_name, se::Platform* platform,
+                     bool transfer_as_literal)
     : LocalDevice(options, attrs),
-      xla_metadata_(device_ordinal, platform, jit_device_name,
-                    shape_representation_fn),
+      xla_metadata_(device_ordinal, platform, jit_device_name),
       device_ordinal_(device_ordinal),
       jit_device_name_(jit_device_name),
       xla_allocator_(nullptr),
       platform_(platform),
-      transfer_as_literal_(transfer_as_literal),
-      shape_representation_fn_(shape_representation_fn) {
+      transfer_as_literal_(transfer_as_literal) {
   VLOG(1) << "Created XLA device " << jit_device_name;
 }
 
@@ -239,8 +232,8 @@ Status XlaDevice::CreateAndSetGpuDeviceInfo() {
     // gpu_device_info_->default_context.
     gpu_device_info_ = absl::make_unique<GpuDeviceInfo>();
     gpu_device_info_->stream = stream;
-    gpu_device_info_->default_context = new XlaDeviceContext(
-        stream, client(), transfer_as_literal_, shape_representation_fn_);
+    gpu_device_info_->default_context =
+        new XlaDeviceContext(stream, client(), transfer_as_literal_);
     set_tensorflow_gpu_device_info(gpu_device_info_.get());
   }
 
@@ -254,8 +247,7 @@ Status XlaDevice::FillContextMap(const Graph* graph,
   TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
   // Call GetAllocator for the side-effect of ensuring the allocator is created.
   GetAllocator({});
-  auto ctx = new XlaDeviceContext(stream, client(), transfer_as_literal_,
-                                  shape_representation_fn_);
+  auto ctx = new XlaDeviceContext(stream, client(), transfer_as_literal_);
   for (Node* n : graph->nodes()) {
     VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name();
     ctx->Ref();
@@ -302,8 +294,7 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
     Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape());
     Notification n;
     TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
-    XlaTransferManager manager(stream, client(), transfer_as_literal_,
-                               shape_representation_fn_);
+    XlaTransferManager manager(stream, client(), transfer_as_literal_);
     manager.CopyCPUTensorToDevice(&parsed, this, &copy,
                                   [&n, &status](const Status& s) {
                                     status = s;
index d5d345d..3ae8730 100644 (file)
@@ -17,7 +17,8 @@ limitations under the License.
 // runtime.
 //
 // Operators assigned to an XlaDevice are compiled into XLA computations.
-// Tensors on an XlaDevice are thin wrappers around XLA ScopedShapedBuffers.
+// Tensors on an XlaDevice are thin wrappers around XLA GlobalDataHandles; state
+// is managed by XLA.
 //
 // XlaDevice is instantiated separately for each XLA backend (e.g., CPU or GPU),
 // under different names (e.g., XLA_CPU or XLA_GPU).
@@ -26,7 +27,6 @@ limitations under the License.
 #define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_
 
 #include "tensorflow/compiler/jit/xla_tensor.h"
-#include "tensorflow/compiler/tf2xla/xla_compiler.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/core/common_runtime/device_factory.h"
@@ -50,8 +50,7 @@ class XlaDevice : public LocalDevice {
   class Metadata {
    public:
     Metadata(int device_ordinal, se::Platform* platform,
-             const DeviceType& device_type,
-             XlaCompiler::ShapeRepresentationFn shape_representation_fn);
+             const DeviceType& device_type);
 
     // The index of the device on this host.
     int device_ordinal() const;
@@ -59,15 +58,11 @@ class XlaDevice : public LocalDevice {
     se::Platform* platform() const;
     xla::LocalClient* client() const;
     const DeviceType& jit_device_type() const;
-    const XlaCompiler::ShapeRepresentationFn& shape_representation_fn() const {
-      return shape_representation_fn_;
-    }
 
    private:
     const int device_ordinal_;
     const DeviceType device_type_;
     se::Platform* platform_;  // Not owned.
-    XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
 
     TF_DISALLOW_COPY_AND_ASSIGN(Metadata);
   };
@@ -81,19 +76,16 @@ class XlaDevice : public LocalDevice {
   // 'transfer_as_literal' is true if device<->host transfers must be done using
   // XLA's TransferLiteral{To,From}Device interface. If false, we can use
   // ThenMemcpy instead.
-  static Status Create(
-      const string& platform_name, const string& device_name,
-      int device_ordinal, const string& jit_device_name,
-      const SessionOptions& options, const string& name_prefix,
-      const XlaOpRegistry::DeviceRegistration& registration,
-      bool transfer_as_literal,
-      const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
-      std::unique_ptr<XlaDevice>* device);
+  static Status Create(const string& platform_name, const string& device_name,
+                       int device_ordinal, const string& jit_device_name,
+                       const SessionOptions& options, const string& name_prefix,
+                       const XlaOpRegistry::DeviceRegistration& registration,
+                       bool transfer_as_literal,
+                       std::unique_ptr<XlaDevice>* device);
 
   XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs,
             int device_ordinal, const DeviceType& jit_device_name,
-            se::Platform* platform, bool transfer_as_literal,
-            const XlaCompiler::ShapeRepresentationFn& shape_representation_fn);
+            se::Platform* platform, bool transfer_as_literal);
   ~XlaDevice() override;
 
   Allocator* GetAllocator(AllocatorAttributes attr) override;
@@ -124,8 +116,8 @@ class XlaDevice : public LocalDevice {
   // The name of the device that is used to compile Ops for this XlaDevice.
   DeviceType jit_device_name_;
   // Memory allocator associated with this device.
-  Allocator* xla_allocator_;  // Not owned.
-  se::Platform* platform_;    // Not owned.
+  Allocator* xla_allocator_;                   // Not owned.
+  se::Platform* platform_;                     // Not owned.
   // Stream associated with this device. Operations enqueued on this
   // stream are executed on the device. Operations include data
   // copying back and forth between CPU and the device, and
@@ -134,7 +126,6 @@ class XlaDevice : public LocalDevice {
   // Must we use XLA's transfer manager for correct host<->device transfers? if
   // false, we can use ThenMemcpy() instead.
   bool transfer_as_literal_;
-  XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
 
   // If set, holds default device context (that we must Unref)
   // and its stream.
index ff30b62..bf8c188 100644 (file)
@@ -47,14 +47,13 @@ void XlaDeviceAllocator::DeallocateRaw(void* ptr) {
 
 void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); }
 
-XlaTransferManager::XlaTransferManager(
-    se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
-    XlaCompiler::ShapeRepresentationFn shape_representation_fn)
+XlaTransferManager::XlaTransferManager(se::Stream* stream,
+                                       xla::LocalClient* client,
+                                       bool transfer_as_literal)
     : stream_(stream),
       client_(client),
       transfer_manager_(client->backend().transfer_manager()),
-      transfer_as_literal_(transfer_as_literal),
-      shape_representation_fn_(std::move(shape_representation_fn)) {}
+      transfer_as_literal_(transfer_as_literal) {}
 
 Status XlaTransferManager::TransferLiteralToDevice(
     const Tensor& host_tensor, Tensor* device_tensor) const {
@@ -77,15 +76,7 @@ Status XlaTransferManager::TransferLiteralFromDevice(
                       transfer_manager_->TransferLiteralFromDevice(
                           stream_->parent(), shaped_buffer));
   VLOG(1) << "Transfer from device as literal: " << literal->ToString();
-  Tensor tensor;
-  TF_RETURN_IF_ERROR(
-      LiteralToHostTensor(*literal, host_tensor->dtype(), &tensor));
-  // Reshape the tensor back to its declared shape.
-  if (!host_tensor->CopyFrom(tensor, device_tensor.shape())) {
-    return errors::Internal(
-        "Tensor::CopyFrom failed when copying from XLA device to CPU");
-  }
-  return Status::OK();
+  return LiteralToHostTensor(*literal, host_tensor->dtype(), host_tensor);
 }
 
 void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
@@ -105,17 +96,9 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
 
     XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
     CHECK(xla_tensor);
-
-    TensorShape shape;
-    if (shape_representation_fn_) {
-      shape = shape_representation_fn_(device_tensor->shape(),
-                                       device_tensor->dtype());
-    } else {
-      shape = device_tensor->shape();
-    }
     if (!xla_tensor->has_shaped_buffer()) {
       Status s = xla_tensor->AllocateShapedBuffer(
-          device_tensor->dtype(), shape, client_,
+          device_tensor->dtype(), device_tensor->shape(), client_,
           stream_->parent()->device_ordinal());
       if (!s.ok()) {
         done(s);
@@ -123,18 +106,12 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
       }
     }
 
+    se::DeviceMemoryBase dev_dst_ptr =
+        XlaTensor::DeviceMemoryFromTensor(*device_tensor);
     Status status;
     if (transfer_as_literal_) {
-      Tensor reshaped_cpu_tensor;
-      if (!reshaped_cpu_tensor.CopyFrom(*cpu_tensor, shape)) {
-        done(errors::Internal(
-            "Tensor::CopyFrom failed when copying from CPU to XLA device"));
-        return;
-      }
-      status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor);
+      status = TransferLiteralToDevice(*cpu_tensor, device_tensor);
     } else {
-      se::DeviceMemoryBase dev_dst_ptr =
-          XlaTensor::DeviceMemoryFromTensor(*device_tensor);
       stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes);
       // TODO(hpucha): Make this asynchronous.
       Status block_status = stream_->BlockHostUntilDone();
@@ -194,11 +171,9 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
   done(Status::OK());
 }
 
-XlaDeviceContext::XlaDeviceContext(
-    se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
-    XlaCompiler::ShapeRepresentationFn shape_representation_fn)
-    : manager_(stream, client, transfer_as_literal,
-               std::move(shape_representation_fn)) {}
+XlaDeviceContext::XlaDeviceContext(se::Stream* stream, xla::LocalClient* client,
+                                   bool transfer_as_literal)
+    : manager_(stream, client, transfer_as_literal) {}
 
 void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
                                              Device* device,
index 9af9655..d7f5f1d 100644 (file)
@@ -19,7 +19,6 @@ limitations under the License.
 #include <memory>
 
 #include "tensorflow/compiler/jit/xla_tensor.h"
-#include "tensorflow/compiler/tf2xla/xla_compiler.h"
 #include "tensorflow/compiler/xla/client/global_data.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/core/framework/allocator.h"
@@ -46,9 +45,8 @@ class XlaDeviceAllocator : public Allocator {
 // Helper class for managing data transfers between host and XLA devices.
 class XlaTransferManager {
  public:
-  explicit XlaTransferManager(
-      se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
-      XlaCompiler::ShapeRepresentationFn shape_representation_fn);
+  explicit XlaTransferManager(se::Stream* stream, xla::LocalClient* client,
+                              bool transfer_as_literal);
 
   void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
                              Tensor* device_tensor, StatusCallback done) const;
@@ -71,8 +69,7 @@ class XlaTransferManager {
   // Transfer manager, for marshalling data to and from the device.
   xla::TransferManager* transfer_manager_;
   // True if we must use XLA's TransferManager for correct device transfers.
-  const bool transfer_as_literal_;
-  const XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
+  bool transfer_as_literal_;
 };
 
 // DeviceContext for operators assigned to XlaDevice devices. The
@@ -80,9 +77,8 @@ class XlaTransferManager {
 // wraps the methods in XlaTransferManager.
 class XlaDeviceContext : public DeviceContext {
  public:
-  explicit XlaDeviceContext(
-      se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
-      XlaCompiler::ShapeRepresentationFn shape_representation_fn);
+  explicit XlaDeviceContext(se::Stream* stream, xla::LocalClient* client,
+                            bool transfer_as_literal);
 
   void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
                              Tensor* device_tensor,
index 26842fb..a8afbf9 100644 (file)
@@ -48,8 +48,7 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options,
   Status status =
       XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options,
                         name_prefix, registration,
-                        /*transfer_as_literal=*/false,
-                        /*shape_representation_fn=*/{}, &device);
+                        /*transfer_as_literal=*/false, &device);
   if (!status.ok()) {
     // Treat failures as non-fatal; there might not be a GPU in the machine.
     VLOG(1) << "Failed to create XLA_GPU device: " << status;
index d0c7a93..6a0f557 100644 (file)
@@ -195,6 +195,11 @@ void XlaComputationLaunchContext::PopulateOutputs(
 
         OP_REQUIRES_OK(
             ctx, ctx->allocate_output(i, const_tensor.shape(), &output_tensor));
+        if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) {
+          OP_REQUIRES_OK(ctx, xla_tensor->AllocateShapedBuffer(
+                                  const_tensor.dtype(), const_tensor.shape(),
+                                  client_, stream->parent()->device_ordinal()));
+        }
 
         Device* device = dynamic_cast<Device*>(ctx->device());
         OP_REQUIRES(ctx, device != nullptr,
index 213ab95..96dfc8d 100644 (file)
@@ -42,7 +42,7 @@ py_library(
         "//tensorflow/python:array_ops",
         "//tensorflow/python:client",
         "//tensorflow/python:client_testlib",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:platform",
         "//tensorflow/python:random_seed",
         "//tensorflow/python:session",
@@ -58,7 +58,7 @@ tf_xla_py_test(
     deps = [
         ":xla_test",
         "//tensorflow/python:array_ops",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:platform_test",
         "//tensorflow/python:training",
@@ -72,7 +72,7 @@ tf_xla_py_test(
     deps = [
         ":xla_test",
         "//tensorflow/python:array_ops",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:platform_test",
         "//tensorflow/python:training",
@@ -93,7 +93,7 @@ tf_xla_py_test(
     deps = [
         ":xla_test",
         "//tensorflow/python:array_ops",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:platform_test",
     ],
@@ -111,7 +111,7 @@ tf_xla_py_test(
         ":xla_test",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:bitwise_ops",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:math_ops_gen",
         "//tensorflow/python:nn_ops",
@@ -127,7 +127,7 @@ tf_xla_py_test(
     tags = ["optonly"],
     deps = [
         ":xla_test",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:platform_test",
         "//tensorflow/python:random_ops",
     ],
@@ -141,7 +141,7 @@ tf_xla_py_test(
     deps = [
         ":xla_test",
         "//tensorflow/python:array_ops",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:platform_test",
         "//tensorflow/python:training",
@@ -156,7 +156,7 @@ tf_xla_py_test(
     deps = [
         ":xla_test",
         "//tensorflow/python:array_ops",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:platform_test",
         "//tensorflow/python:training",
@@ -170,7 +170,7 @@ tf_xla_py_test(
     deps = [
         ":xla_test",
         "//tensorflow/python:array_ops",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:platform_test",
     ],
@@ -184,7 +184,7 @@ tf_xla_py_test(
         ":xla_test",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:array_ops_gen",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:gradient_checker",
         "//tensorflow/python:gradients",
         "//tensorflow/python:math_ops",
@@ -209,7 +209,7 @@ tf_xla_py_test(
         ":xla_test",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:array_ops_gen",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:gradient_checker",
         "//tensorflow/python:gradients",
         "//tensorflow/python:math_ops",
@@ -225,7 +225,7 @@ tf_xla_py_test(
     deps = [
         ":xla_test",
         "//tensorflow/python:array_ops",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:nn",
         "//tensorflow/python:nn_ops",
         "//tensorflow/python:nn_ops_gen",
@@ -241,7 +241,7 @@ tf_xla_py_test(
     deps = [
         ":xla_test",
         "//tensorflow/python:array_ops",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:nn",
         "//tensorflow/python:nn_ops",
         "//tensorflow/python:nn_ops_gen",
@@ -263,7 +263,7 @@ tf_xla_py_test(
     deps = [
         ":xla_test",
         "//tensorflow/python:array_ops",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:nn",
         "//tensorflow/python:nn_ops",
         "//tensorflow/python:nn_ops_gen",
@@ -291,7 +291,7 @@ tf_xla_py_test(
         ":xla_test",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:data_flow_ops",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:platform_test",
     ],
 )
@@ -307,7 +307,7 @@ tf_xla_py_test(
     deps = [
         ":xla_test",
         "//tensorflow/python:array_ops",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:platform_test",
     ],
 )
@@ -326,7 +326,7 @@ tf_xla_py_test(
     deps = [
         ":xla_test",
         "//tensorflow/python:array_ops",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:layers",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:nn",
@@ -346,7 +346,7 @@ tf_xla_py_test(
         "//tensorflow/contrib/signal:signal_py",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:extra_py_tests_deps",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:platform_test",
         "//tensorflow/python:spectral_ops",
     ],
@@ -360,7 +360,7 @@ tf_xla_py_test(
         ":xla_test",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:data_flow_ops",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:platform_test",
     ],
 )
@@ -372,7 +372,7 @@ tf_xla_py_test(
     deps = [
         ":xla_test",
         "//tensorflow/python:array_ops",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:platform_test",
         "//tensorflow/python:training",
@@ -388,7 +388,7 @@ tf_xla_py_test(
     deps = [
         ":xla_test",
         "//tensorflow/python:array_ops",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:platform_test",
     ],
 )
@@ -403,7 +403,7 @@ tf_xla_py_test(
     deps = [
         ":xla_test",
         "//tensorflow/python:array_ops",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:image_ops",
         "//tensorflow/python:platform_test",
     ],
@@ -431,7 +431,7 @@ tf_xla_py_test(
     deps = [
         ":xla_test",
         "//tensorflow/python:array_ops",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:nn",
         "//tensorflow/python:nn_ops_gen",
         "//tensorflow/python:platform_test",
@@ -446,7 +446,7 @@ tf_xla_py_test(
     deps = [
         ":xla_test",
         "//tensorflow/python:array_ops",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:platform_test",
     ],
 )
@@ -458,7 +458,7 @@ tf_xla_py_test(
     deps = [
         ":xla_test",
         "//tensorflow/python:array_ops",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:platform_test",
         "//tensorflow/python:training",
@@ -472,7 +472,7 @@ tf_xla_py_test(
     deps = [
         ":xla_test",
         "//tensorflow/python:array_ops",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:platform_test",
     ],
@@ -485,7 +485,7 @@ tf_xla_py_test(
     deps = [
         ":xla_test",
         "//tensorflow/python:control_flow_ops",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:platform_test",
     ],
 )
@@ -498,7 +498,7 @@ tf_xla_py_test(
     deps = [
         ":xla_test",
         "//tensorflow/python:array_ops",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:nn_ops",
         "//tensorflow/python:nn_ops_gen",
         "//tensorflow/python:platform_test",
@@ -513,7 +513,7 @@ tf_xla_py_test(
     deps = [
         ":xla_test",
         "//tensorflow/python:array_ops",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:nn_ops",
         "//tensorflow/python:nn_ops_gen",
         "//tensorflow/python:platform_test",
@@ -530,7 +530,7 @@ tf_xla_py_test(
     ],
     deps = [
         ":xla_test",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:platform_test",
         "//tensorflow/python:random_ops",
     ],
@@ -545,7 +545,7 @@ tf_xla_py_test(
         ":xla_test",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:errors",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:platform_test",
     ],
@@ -561,7 +561,7 @@ tf_xla_py_test(
         "//tensorflow/compiler/tf2xla/python:xla",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:errors",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:platform_test",
     ],
@@ -574,7 +574,7 @@ tf_xla_py_test(
     deps = [
         ":xla_test",
         "//tensorflow/python:array_ops",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
     ],
 )
 
@@ -586,7 +586,7 @@ tf_xla_py_test(
     deps = [
         ":xla_test",
         "//tensorflow/python:array_ops",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:platform_test",
     ],
 )
@@ -598,7 +598,7 @@ tf_xla_py_test(
     deps = [
         ":xla_test",
         "//tensorflow/python:array_ops",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:platform_test",
         "//tensorflow/python:training",
@@ -613,7 +613,7 @@ tf_xla_py_test(
     deps = [
         ":xla_test",
         "//tensorflow/python:array_ops",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:platform_test",
     ],
@@ -626,7 +626,7 @@ tf_xla_py_test(
     deps = [
         ":xla_test",
         "//tensorflow/python:array_ops",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:math_ops_gen",
         "//tensorflow/python:platform_test",
@@ -641,7 +641,7 @@ tf_xla_py_test(
     deps = [
         ":xla_test",
         "//tensorflow/python:array_ops",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:platform_test",
     ],
@@ -657,7 +657,7 @@ tf_xla_py_test(
         ":xla_test",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:data_flow_ops",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:platform_test",
     ],
 )
@@ -670,7 +670,7 @@ tf_xla_py_test(
     deps = [
         ":xla_test",
         "//tensorflow/contrib/stateless",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:platform_test",
     ],
 )
@@ -684,7 +684,7 @@ tf_xla_py_test(
     deps = [
         ":xla_test",
         "//tensorflow/python:array_ops",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:math_ops_gen",
         "//tensorflow/python:nn_ops",
@@ -703,7 +703,7 @@ tf_xla_py_test(
     deps = [
         ":xla_test",
         "//tensorflow/python:array_ops",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:platform_test",
     ],
@@ -716,7 +716,7 @@ tf_xla_py_test(
     deps = [
         ":xla_test",
         "//tensorflow/python:array_ops",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:nn_ops",
         "//tensorflow/python:nn_ops_gen",
@@ -730,7 +730,7 @@ tf_xla_py_test(
     srcs = ["fused_batchnorm_test.py"],
     deps = [
         ":xla_test",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:math_ops_gen",
         "//tensorflow/python:nn",
@@ -749,7 +749,7 @@ tf_xla_py_test(
     deps = [
         ":xla_test",
         "//tensorflow/python:array_ops",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:math_ops_gen",
         "//tensorflow/python:nn_ops",
@@ -768,7 +768,7 @@ tf_xla_py_test(
         ":xla_test",
         "//tensorflow/compiler/tf2xla/python:xla",
         "//tensorflow/python:array_ops",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:platform_test",
         "//tensorflow/python:training",
     ],
@@ -783,7 +783,7 @@ tf_xla_py_test(
         ":xla_test",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:data_flow_ops",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:platform_test",
     ],
 )
@@ -795,7 +795,7 @@ tf_xla_py_test(
     deps = [
         ":xla_test",
         "//tensorflow/python:array_ops",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:platform_test",
     ],
 )
@@ -808,34 +808,21 @@ tf_xla_py_test(
     deps = [
         ":xla_test",
         "//tensorflow/python:array_ops",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:platform_test",
     ],
 )
 
-tf_xla_py_test(
+cuda_py_test(
     name = "xla_device_test",
     size = "small",
     srcs = ["xla_device_test.py"],
-    tags = ["optonly"],
-    deps = [
-        ":xla_test",
-        "//tensorflow/python:array_ops",
-        "//tensorflow/python:framework",
-        "//tensorflow/python:platform_test",
-    ],
-)
-
-cuda_py_test(
-    name = "xla_device_gpu_test",
-    size = "small",
-    srcs = ["xla_device_gpu_test.py"],
     additional_deps = [
         "//tensorflow/python:array_ops",
         "//tensorflow/python:client",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:control_flow_ops",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:math_ops",
     ],
 )
@@ -852,6 +839,7 @@ cuda_py_test(
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:control_flow_ops",
         "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:gradients",
         "//tensorflow/python:layers",
         "//tensorflow/python:math_ops",
@@ -899,7 +887,7 @@ py_library(
     srcs_version = "PY2AND3",
     deps = [
         "//tensorflow/python:array_ops",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:random_ops",
         "//tensorflow/python:variables",
@@ -914,7 +902,7 @@ cuda_py_test(
         ":xla_test",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:gradients",
         "//tensorflow/python:init_ops",
         "//tensorflow/python:math_ops",
@@ -952,7 +940,7 @@ tf_xla_py_test(
     srcs = ["fake_quant_ops_test.py"],
     deps = [
         ":xla_test",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:platform_test",
     ],
 )
@@ -964,7 +952,7 @@ tf_xla_py_test(
     deps = [
         ":xla_test",
         "//tensorflow/python:array_ops",
-        "//tensorflow/python:framework",
+        "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:platform_test",
     ],
 )
diff --git a/tensorflow/compiler/tests/xla_device_gpu_test.py b/tensorflow/compiler/tests/xla_device_gpu_test.py
deleted file mode 100644 (file)
index 1e30ebd..0000000
+++ /dev/null
@@ -1,48 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Test cases for XLA devices."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python.client import session as session_lib
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import test
-
-
-class XlaDeviceGpuTest(test.TestCase):
-
-  def testCopiesToAndFromGpuWork(self):
-    """Tests that copies between GPU and XLA devices work."""
-    if not test.is_gpu_available():
-      return
-
-    with session_lib.Session() as sess:
-      x = array_ops.placeholder(dtypes.float32, [2])
-      with ops.device("GPU"):
-        y = x * 2
-      with ops.device("device:XLA_CPU:0"):
-        z = y * y
-      with ops.device("GPU"):
-        w = y + z
-      result = sess.run(w, {x: [1.5, 0.5]})
-    self.assertAllClose(result, [12., 2.], rtol=1e-3)
-
-
-if __name__ == "__main__":
-  test.main()
index b707bd0..f5c228f 100644 (file)
@@ -1,4 +1,4 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,33 +18,30 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import numpy as np
-
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.python.client import session as session_lib
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.platform import test
 
 
-class XlaDeviceTest(XLATestCase):
+class XlaDeviceTest(test.TestCase):
 
   def testCopies(self):
-    """Tests that copies onto and off XLA devices work."""
-    shapes = [[0], [1], [1, 0], [1024, 0], [1024, 1], [3, 777], [777, 3],
-              [16384, 1], [1, 16384], [1, 20000, 1, 1]]
-    for dtype in self.numeric_types:
-      for shape in shapes:
-        with self.test_session() as sess:
-          with ops.device("CPU"):
-            x = array_ops.placeholder(dtype, shape)
-          with self.test_scope():
-            y = x + x
-          with ops.device("CPU"):
-            z = array_ops.identity(y)
-
-          inputs = np.random.randint(-100, 100, shape).astype(dtype)
-          result = sess.run(z, {x: inputs})
-        self.assertAllCloseAccordingToType(result, inputs + inputs)
+    """Tests that copies between GPU and XLA devices work."""
+    if not test.is_gpu_available():
+      return
+
+    with session_lib.Session() as sess:
+      x = array_ops.placeholder(dtypes.float32, [2])
+      with ops.device("GPU"):
+        y = x * 2
+      with ops.device("device:XLA_CPU:0"):
+        z = y * y
+      with ops.device("GPU"):
+        w = y + z
+      result = sess.run(w, {x: [1.5, 0.5]})
+    self.assertAllClose(result, [12., 2.], rtol=1e-3)
 
 
 if __name__ == "__main__":
index cd57452..4fca51f 100644 (file)
@@ -325,7 +325,6 @@ tf_cc_test(
         "//tensorflow/compiler/tf2xla/kernels:xla_ops",
         "//tensorflow/compiler/xla:literal_util",
         "//tensorflow/compiler/xla:shape_util",
-        "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla/client:client_library",
         "//tensorflow/compiler/xla/client:local_client",
         "//tensorflow/compiler/xla/service:cpu_plugin",
index a567226..7054729 100644 (file)
@@ -55,24 +55,18 @@ class RetvalOp : public XlaOpKernel {
       }
 
       XlaContext& tc = XlaContext::Get(ctx);
-      if (tc.resolve_compile_time_constants() &&
-          (input_shape.num_elements() == 0 || is_constant.ValueOrDie())) {
+      if (input_shape.num_elements() == 0 || is_constant.ValueOrDie()) {
         xla::Literal literal;
         OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal));
         OP_REQUIRES_OK(ctx, tc.AddConstRetval(index_, dtype_, literal));
       } else {
-        TensorShape shape = ctx->InputShape(0);
-        TensorShape representation_shape =
-            tc.is_entry_computation()
-                ? tc.RepresentationShape(shape, ctx->input_type(0))
-                : shape;
         // The core from which a return value is returned depends on the core
-        // assignment of the input to the retvalSince we can't change the core
-        // assignment of <input> as this point, we must always introduce a
-        // reshape here, even if the shape does not change.
-        xla::XlaOp reshape =
-            ctx->builder()->Reshape(input, representation_shape.dim_sizes());
-        tc.AddRetval(index_, dtype_, shape, reshape);
+        // assignment of the input to the retval .Since we can't change the core
+        // assignment of <input> as this point, create a tuple/get-tuple-element
+        // combination so that the core will be set on them.
+        auto tuple_elem =
+            ctx->builder()->GetTupleElement(ctx->builder()->Tuple({input}), 0);
+        tc.AddRetval(index_, dtype_, tuple_elem);
       }
     }
   }
index 962e534..3d1946c 100644 (file)
@@ -15,9 +15,10 @@ limitations under the License.
 
 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
 
+#include <deque>
 #include <numeric>
-#include <vector>
 
+#include "tensorflow/compiler/tf2xla/const_analysis.h"
 #include "tensorflow/compiler/tf2xla/dump_graph.h"
 #include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
 #include "tensorflow/compiler/tf2xla/graph_compiler.h"
@@ -27,6 +28,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/type_util.h"
 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
 #include "tensorflow/compiler/tf2xla/xla_context.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/executor.h"
@@ -38,6 +40,7 @@ limitations under the License.
 #include "tensorflow/core/graph/node_builder.h"
 #include "tensorflow/core/lib/hash/hash.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/version.h"
 
 namespace tensorflow {
 namespace {
@@ -108,9 +111,9 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options)
   flib_runtime_ = pflr_->GetFLR(device_->name());
 
   // The default variable representation shape is the identity function.
-  if (!options_.shape_representation_fn) {
-    options_.shape_representation_fn = [](const TensorShape& shape,
-                                          DataType type) { return shape; };
+  if (!options_.variable_representation_shape_fn) {
+    options_.variable_representation_shape_fn =
+        [](const TensorShape& shape, DataType type) { return shape; };
   }
 }
 
@@ -227,25 +230,20 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options,
 
 // Computes the XLA shape for argument 'arg'.
 Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
-                                        bool is_entry_computation,
                                         xla::Shape* xla_shape) {
   switch (arg.kind) {
     case XlaCompiler::Argument::kConstant:
-      LOG(FATAL) << "Unreachable case";
-    case XlaCompiler::Argument::kParameter: {
-      TensorShape shape =
-          is_entry_computation
-              ? options_.shape_representation_fn(arg.shape, arg.type)
-              : arg.shape;
-      return TensorShapeToXLAShape(arg.type, shape, xla_shape);
-    }
+      return TensorShapeToXLAShape(arg.type, arg.constant_value.shape(),
+                                   xla_shape);
+    case XlaCompiler::Argument::kParameter:
+      return TensorShapeToXLAShape(arg.type, arg.shape, xla_shape);
     case XlaCompiler::Argument::kResource: {
       TF_RET_CHECK(arg.initialized);
 
       switch (arg.resource_kind) {
         case XlaResource::kVariable: {
           TensorShape representation_shape =
-              options_.shape_representation_fn(arg.shape, arg.type);
+              options_.variable_representation_shape_fn(arg.shape, arg.type);
           return TensorShapeToXLAShape(arg.type, representation_shape,
                                        xla_shape);
         }
@@ -339,25 +337,16 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
 Status BuildComputation(
     const std::vector<XlaCompiler::Argument>& args,
     const std::vector<int>& arg_cores,
-    const std::vector<XlaContext::Retval>& retvals,
+    const std::vector<XlaExpression>& retvals,
     const std::vector<std::unique_ptr<XlaResource>>& resources,
     bool return_updated_values_for_all_resources, xla::XlaBuilder* builder,
     xla::XlaComputation* computation, int* num_computation_outputs,
     int* num_nonconst_outputs,
-    std::vector<XlaCompiler::OutputDescription>* outputs,
     std::vector<XlaCompiler::ResourceUpdate>* resource_updates) {
   std::vector<xla::XlaOp> elems;
   elems.reserve(retvals.size());
-  for (int i = 0; i < retvals.size(); ++i) {
-    XlaCompiler::OutputDescription& output = (*outputs)[i];
-    output.type = retvals[i].type;
-    output.shape = retvals[i].shape;
-    const XlaExpression& retval = retvals[i].expression;
-    if (retval.has_constant_value()) {
-      output.is_constant = true;
-      output.constant_value = retval.constant_value();
-    } else {
-      output.is_constant = false;
+  for (const XlaExpression& retval : retvals) {
+    if (!retval.has_constant_value()) {
       elems.push_back(retval.handle());
     }
   }
@@ -501,8 +490,8 @@ Status XlaCompiler::BuildArguments(
   std::vector<xla::Shape> arg_shapes(input_mapping->size());
   for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
     // Computes the shapes of non-constant arguments.
-    TF_RETURN_IF_ERROR(XLAShapeForArgument(
-        args[(*input_mapping)[i]], is_entry_computation, &arg_shapes[i]));
+    TF_RETURN_IF_ERROR(
+        XLAShapeForArgument(args[(*input_mapping)[i]], &arg_shapes[i]));
   }
 
   if (use_tuple_arg) {
@@ -578,8 +567,7 @@ Status XlaCompiler::BuildArguments(
 
   builder->ClearOpMetadata();
 
-  // Fill in the handles in non-constant arguments, and reshape parameters
-  // back to their correct shapes.
+  // Fill in the handles in non-constant arguments.
   VLOG(2) << "XLA computation inputs:";
   for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
     const XlaCompiler::Argument& arg = args[input_mapping->at(i)];
@@ -598,9 +586,7 @@ Status XlaCompiler::BuildArguments(
         break;
       }
       case XlaCompiler::Argument::kParameter:
-        // Reshape parameters back to their correct shapes.
-        arg_expression.set_handle(
-            builder->Reshape(arg_handles[i], arg.shape.dim_sizes()));
+        arg_expression.set_handle(arg_handles[i]);
         break;
       case XlaCompiler::Argument::kConstant:
       case XlaCompiler::Argument::kInvalid:
@@ -675,10 +661,10 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
       FunctionalizeControlFlow(graph.get(), local_flib_def_.get()));
 
   xla::XlaBuilder builder(name);
-  XlaContext* context = new XlaContext(
-      this, &builder, options_.allow_cpu_custom_calls,
-      options.resolve_compile_time_constants, options.is_entry_computation,
-      &options_.shape_representation_fn);
+  XlaContext* context =
+      new XlaContext(this, &builder, options_.allow_cpu_custom_calls,
+                     options.resolve_compile_time_constants,
+                     &options_.variable_representation_shape_fn);
   core::ScopedUnref context_unref(context);
 
   std::vector<XlaExpression> arg_expressions;
@@ -695,22 +681,35 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
   int num_nonconst_outputs;
   int num_computation_outputs;
   result->computation = std::make_shared<xla::XlaComputation>();
-  result->outputs.resize(context->retvals().size());
   TF_RETURN_IF_ERROR(BuildComputation(
       args, arg_cores, context->retvals(), context->resources(),
       options.return_updated_values_for_all_resources, &builder,
       result->computation.get(), &num_computation_outputs,
-      &num_nonconst_outputs, &result->outputs, &result->resource_updates));
+      &num_nonconst_outputs, &result->resource_updates));
 
   VLOG(2) << "Outputs: total: " << context->retvals().size()
           << " nonconstant: " << num_nonconst_outputs;
+  result->outputs.resize(context->retvals().size());
+  for (std::vector<XlaExpression>::size_type i = 0;
+       i < context->retvals().size(); ++i) {
+    const XlaExpression& retval = context->retvals()[i];
+    if (retval.has_constant_value()) {
+      OutputDescription& output = result->outputs[i];
+      output.shape = retval.constant_value().shape();
+      output.is_constant = true;
+      output.constant_value = retval.constant_value();
+    }
+  }
 
-  // Compute the XLA output shape, if there is a computation with non-constant
+  // Compute the output shapes, if there is a computation with non-constant
   // outputs.
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::ProgramShape> computation_shape,
-                      client()->GetComputationShape(*result->computation));
+  auto computation_shape = client()->GetComputationShape(*result->computation);
+  if (!computation_shape.ok()) {
+    return computation_shape.status();
+  }
 
-  result->xla_output_shape.Swap(computation_shape->mutable_result());
+  result->xla_output_shape.Swap(
+      computation_shape.ValueOrDie()->mutable_result());
   VLOG(2) << "XLA output shape: "
           << xla::ShapeUtil::HumanString(result->xla_output_shape);
 
@@ -725,6 +724,23 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
   // Tensorflow expects a major-to-minor order of results.
   xla::LayoutUtil::SetToDefaultLayout(&result->xla_output_shape);
 
+  // Converts the output shapes to TensorShapes.
+  int computation_output = 0;
+  for (std::vector<XlaExpression>::size_type i = 0;
+       i < context->retvals().size(); ++i) {
+    const XlaExpression& retval = context->retvals()[i];
+    if (!retval.has_constant_value()) {
+      TF_RET_CHECK(computation_output < num_computation_outputs)
+          << "Computation has more outputs than expected";
+      OutputDescription& output = result->outputs[i];
+      output.is_constant = false;
+      TF_RETURN_IF_ERROR(XLAShapeToTensorShape(
+          xla::ShapeUtil::GetTupleElementShape(result->xla_output_shape,
+                                               computation_output),
+          &output.shape));
+      ++computation_output;
+    }
+  }
   return Status::OK();
 }
 
index 621fbc1..ca6cd82 100644 (file)
@@ -67,15 +67,6 @@ class XlaContext;
 // _Retval values are ordered by _Retval index, whereas kResource values are
 // ordered by the original _Arg position of the variable.
 //
-// If a shape representation function is provided as part of
-// XlaCompiler::CompileOptions, kParameter arguments and return values to an
-// entry computation will be reshaped in accordance to the shape function.
-// Arguments and return values to a non-entry computation are not reshaped.
-// Variable resource arguments are passed and returned in reshaped form, even
-// for non-entry computations. This feature allows TensorFlow to keep on-device
-// tensors with a different shape to their representation inside the XLA
-// computation.
-//
 // In both inputs and outputs, kResource values are placed the end. When
 // emitting While loop bodies, we must ensure that the loop body has
 // identical input and output signatures. By moving variable values
@@ -180,7 +171,7 @@ class XlaCompiler {
   };
 
   struct OutputDescription {
-    // Type and shape of the output. The shape is the unflattened shape.
+    // Type and shape of the output.
     DataType type;
     TensorShape shape;
 
@@ -215,12 +206,10 @@ class XlaCompiler {
     // original arguments, and are not necessarily in the same order.)
     std::vector<int> input_mapping;
 
-    // Input shapes of the computation. If we are flattening inputs, these are
-    // the flattened shapes.
+    // Input shapes of the computation.
     std::vector<xla::Shape> xla_input_shapes;
 
-    // Output shape in XLA format. The output shape is always a tuple. If we
-    // are flattening outputs, these are the flattened shapes.
+    // Output shape in XLA format. The output shape is always a tuple.
     xla::Shape xla_output_shape;
 
     // TensorFlow shapes of outputs, together with the values of any
@@ -241,8 +230,6 @@ class XlaCompiler {
     std::shared_ptr<xla::XlaComputation> computation;
   };
 
-  typedef std::function<TensorShape(const TensorShape&, DataType)>
-      ShapeRepresentationFn;
   struct Options {
     // Name of the compilation device to use. Needs to be live only during
     // XlaCompiler's constructor.
@@ -263,7 +250,8 @@ class XlaCompiler {
     // If set, the XLA representation of variables represented to XLA as the
     // shape given by this shape function. Variables are reshaped to this shape
     // on write, and reshaped to their original shape on read.
-    ShapeRepresentationFn shape_representation_fn;
+    std::function<TensorShape(const TensorShape&, DataType)>
+        variable_representation_shape_fn;
 
     // If not nullptr, populate_resource_manager is called with the
     // compilation device's resource manager when the compilation
@@ -312,8 +300,7 @@ class XlaCompiler {
   // Returns the shape of the XLA parameter for an argument 'arg'.
   // See the class comment for more details about the argument passing
   // convention.
-  Status XLAShapeForArgument(const Argument& arg, bool is_entry_computation,
-                             xla::Shape* xla_shape);
+  Status XLAShapeForArgument(const Argument& arg, xla::Shape* xla_shape);
 
   // Retrieves the channel handle associated with `key`. Allocates
   // a new channel handle if none exists.
index 5670545..4382ffe 100644 (file)
@@ -25,7 +25,6 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
 #include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/framework/common_shape_fns.h"
@@ -751,7 +750,10 @@ TEST_F(XlaCompilerTest, Variables) {
   EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
 }
 
-xla::StatusOr<std::unique_ptr<Graph>> BuildTestGraph() {
+// Tests a simple graph that reads and writes a variable, with a
+// variable_representation_shape_fn passed to the compiler that flattens all
+// variable tensors to vectors.
+TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
   Scope scope = Scope::NewRootScope().ExitOnError();
   auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
   auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
@@ -762,15 +764,7 @@ xla::StatusOr<std::unique_ptr<Graph>> BuildTestGraph() {
   auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1));
   auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0);
   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
-  TF_RETURN_IF_ERROR(scope.ToGraph(graph.get()));
-  return std::move(graph);
-}
-
-// Tests a simple graph that reads and writes a variable, with a
-// shape_representation_fn passed to the compiler that flattens all
-// variable tensors to vectors.
-TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Graph> graph, BuildTestGraph());
+  TF_ASSERT_OK(scope.ToGraph(graph.get()));
 
   // Builds a description of the arguments.
   std::vector<XlaCompiler::Argument> args(2);
@@ -785,33 +779,15 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
 
   // Compiles the graph.
   XlaCompiler::Options options = DefaultOptions();
-  options.shape_representation_fn = [](const TensorShape& shape,
-                                       DataType type) {
+  options.variable_representation_shape_fn = [](const TensorShape& shape,
+                                                DataType type) {
     return TensorShape({shape.num_elements()});
   };
   XlaCompiler compiler(options);
 
-  XlaCompiler::CompileOptions compile_options;
-  compile_options.is_entry_computation = false;  // Only reshape variables.
-
   XlaCompiler::CompilationResult result;
-  TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
-                                     args, &result));
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::ProgramShape> program_shape,
-                          client_->GetComputationShape(*result.computation));
-
-  ASSERT_EQ(program_shape->parameters_size(), 2);
-  EXPECT_TRUE(
-      xla::ShapeUtil::Compatible(program_shape->parameters(0),
-                                 xla::ShapeUtil::MakeShape(xla::S32, {2, 2})));
-  EXPECT_TRUE(xla::ShapeUtil::Compatible(
-      program_shape->parameters(1), xla::ShapeUtil::MakeShape(xla::S32, {4})));
-  EXPECT_TRUE(xla::ShapeUtil::Compatible(
-      program_shape->result(),
-      xla::ShapeUtil::MakeTupleShape(
-          {xla::ShapeUtil::MakeShape(xla::S32, {2, 2}),
-           xla::ShapeUtil::MakeShape(xla::S32, {4})})));
+  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
+                                     std::move(graph), args, &result));
 
   // Tests that the generated computation works.
   std::unique_ptr<xla::Literal> param0_literal =
@@ -839,74 +815,5 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
   EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
 }
 
-TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Graph> graph, BuildTestGraph());
-
-  // Builds a description of the arguments.
-  std::vector<XlaCompiler::Argument> args(2);
-  args[0].kind = XlaCompiler::Argument::kParameter;
-  args[0].type = DT_INT32;
-  args[0].shape = TensorShape({2, 2});
-  args[1].kind = XlaCompiler::Argument::kResource;
-  args[1].resource_kind = XlaResource::kVariable;
-  args[1].initialized = true;
-  args[1].type = DT_INT32;
-  args[1].shape = TensorShape({2, 2});
-
-  // Compiles the graph.
-  XlaCompiler::Options options = DefaultOptions();
-  options.shape_representation_fn = [](const TensorShape& shape,
-                                       DataType type) {
-    return TensorShape({shape.num_elements()});
-  };
-  XlaCompiler compiler(options);
-
-  XlaCompiler::CompileOptions compile_options;
-  compile_options.is_entry_computation = true;  // Reshape args and retvals.
-
-  XlaCompiler::CompilationResult result;
-  TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
-                                     args, &result));
-
-  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::ProgramShape> program_shape,
-                          client_->GetComputationShape(*result.computation));
-
-  ASSERT_EQ(program_shape->parameters_size(), 2);
-  EXPECT_TRUE(xla::ShapeUtil::Compatible(
-      program_shape->parameters(0), xla::ShapeUtil::MakeShape(xla::S32, {4})));
-  EXPECT_TRUE(xla::ShapeUtil::Compatible(
-      program_shape->parameters(1), xla::ShapeUtil::MakeShape(xla::S32, {4})));
-  EXPECT_TRUE(xla::ShapeUtil::Compatible(
-      program_shape->result(),
-      xla::ShapeUtil::MakeTupleShape(
-          {xla::ShapeUtil::MakeShape(xla::S32, {4}),
-           xla::ShapeUtil::MakeShape(xla::S32, {4})})));
-
-  // Tests that the generated computation works.
-  std::unique_ptr<xla::Literal> param0_literal =
-      xla::Literal::CreateR1<int32>({4, 55, 1, -3});
-  std::unique_ptr<xla::Literal> param1_literal =
-      xla::Literal::CreateR1<int32>({22, 11, 33, 404});
-  std::unique_ptr<xla::GlobalData> param0_data =
-      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
-  std::unique_ptr<xla::GlobalData> param1_data =
-      client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
-
-  std::unique_ptr<xla::GlobalData> actual =
-      client_
-          ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
-          .ConsumeValueOrDie();
-  std::unique_ptr<xla::Literal> actual_literal =
-      client_->Transfer(*actual).ConsumeValueOrDie();
-
-  std::unique_ptr<xla::Literal> expected0 =
-      xla::Literal::CreateR1<int32>({27, 67, 35, 402});
-  std::unique_ptr<xla::Literal> expected1 =
-      xla::Literal::CreateR1<int32>({26, 66, 34, 401});
-  std::unique_ptr<xla::Literal> expected_literal =
-      xla::Literal::MakeTuple({expected0.get(), expected1.get()});
-  EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
-}
-
 }  // namespace
 }  // namespace tensorflow
index 098072d..3dd2d18 100644 (file)
@@ -65,30 +65,26 @@ void XlaContext::set_args(std::vector<XlaExpression> args) {
 XlaContext::XlaContext(
     XlaCompiler* compiler, xla::XlaBuilder* builder,
     bool allow_cpu_custom_calls, bool resolve_compile_time_constants,
-    bool is_entry_computation,
     const std::function<TensorShape(const TensorShape&, DataType)>*
-        shape_representation_fn)
+        variable_representation_shape_fn)
     : compiler_(compiler),
       builder_(builder),
       allow_cpu_custom_calls_(allow_cpu_custom_calls),
       resolve_compile_time_constants_(resolve_compile_time_constants),
-      is_entry_computation_(is_entry_computation),
-      shape_representation_fn_(shape_representation_fn) {}
+      variable_representation_shape_fn_(variable_representation_shape_fn) {}
 
 string XlaContext::DebugString() { return "TLA JIT context"; }
 
 // This is called by the Retval Op to associate a computed value
 // with a specific return value of the subgraph.
 void XlaContext::AddRetval(int retval_index, DataType type,
-                           const TensorShape& shape, const xla::XlaOp& handle) {
+                           const xla::XlaOp& handle) {
   VLOG(1) << "Added retval index " << retval_index << " to XLA computation";
   // Add the return value to the list being built up.
   if (retvals_.size() <= retval_index) {
     retvals_.resize(retval_index + 1);
   }
-  XlaExpression e;
-  e.set_handle(handle);
-  retvals_[retval_index] = Retval{type, shape, e};
+  retvals_[retval_index].set_handle(handle);
 }
 
 Status XlaContext::AddConstRetval(int retval_index, DataType dtype,
@@ -98,11 +94,13 @@ Status XlaContext::AddConstRetval(int retval_index, DataType dtype,
   if (retvals_.size() <= retval_index) {
     retvals_.resize(retval_index + 1);
   }
-  Tensor value;
-  TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value));
-  XlaExpression e;
-  e.set_constant_value(value);
-  retvals_[retval_index] = Retval{dtype, value.shape(), e};
+  if (resolve_compile_time_constants_) {
+    Tensor value;
+    TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value));
+    retvals_[retval_index].set_constant_value(std::move(value));
+  } else {
+    retvals_[retval_index].set_handle(builder_->ConstantLiteral(literal));
+  }
   return Status::OK();
 }
 
@@ -119,9 +117,9 @@ Status XlaContext::CreateResource(
   return Status::OK();
 }
 
-TensorShape XlaContext::RepresentationShape(const TensorShape& shape,
-                                            DataType type) const {
-  return (*shape_representation_fn_)(shape, type);
+TensorShape XlaContext::VariableRepresentationShape(const TensorShape& shape,
+                                                    DataType type) const {
+  return (*variable_representation_shape_fn_)(shape, type);
 }
 
 const xla::XlaComputation* XlaContext::GetOrCreateMax(const DataType type) {
index 3ad2b2e..1136ffe 100644 (file)
@@ -42,13 +42,11 @@ class XlaContext : public ResourceBase {
   static XlaContext& Get(const OpKernelContext* ctx);
   static XlaContext& Get(const XlaOpKernelContext* ctx);
 
-  // Creates a new XlaContext. See the documentation on the class data fields
-  // for descriptions of the arguments.
+  // Creates a new XlaContext.
   XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder,
              bool allow_cpu_custom_calls, bool resolve_compile_time_constants,
-             bool is_entry_computation,
              const std::function<TensorShape(const TensorShape&, DataType)>*
-                 shape_representation_fn);
+                 variable_representation_shape_fn);
 
   // Virtual method defined by ResourceBase.
   string DebugString() override;
@@ -60,26 +58,14 @@ class XlaContext : public ResourceBase {
 
   bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; }
 
-  bool resolve_compile_time_constants() const {
-    return resolve_compile_time_constants_;
-  }
-  bool is_entry_computation() const { return is_entry_computation_; }
-
   const std::vector<XlaExpression>& args() const { return args_; }
   void set_args(std::vector<XlaExpression> args);
 
-  struct Retval {
-    DataType type;
-    TensorShape shape;
-    // An XlaExpression representing the Retval's value.
-    XlaExpression expression;
-  };
-  const std::vector<Retval>& retvals() { return retvals_; }
+  const std::vector<XlaExpression>& retvals() { return retvals_; }
 
   // This is called by the Retval Op to associate a computed value
   // with a specific return value of the subgraph.
-  void AddRetval(int retval_index, DataType type, const TensorShape& shape,
-                 const xla::XlaOp& handle);
+  void AddRetval(int retval_index, DataType type, const xla::XlaOp& handle);
 
   // As for Retval, but for return values that are compile-time constants.
   Status AddConstRetval(int retval_index, DataType dtype,
@@ -100,9 +86,9 @@ class XlaContext : public ResourceBase {
   }
 
   // Returns the XLA shape to be used to represent a variable of TF `shape`
-  // and `type`, or of an argument or return value of a top-level computation.
-  TensorShape RepresentationShape(const TensorShape& shape,
-                                  DataType type) const;
+  // and `type`.
+  TensorShape VariableRepresentationShape(const TensorShape& shape,
+                                          DataType type) const;
 
   // Get an XLA lambda to compute Max. This is cached in the
   // XlaContext since it may be used by multiple Ops. There is a
@@ -145,19 +131,15 @@ class XlaContext : public ResourceBase {
   std::vector<XlaExpression> args_;
 
   // Return values of the Tensorflow graph, indexed by _Retval index.
-  std::vector<Retval> retvals_;
+  std::vector<XlaExpression> retvals_;
 
   // Holds ownership of resources. The resources are not ordered.
   std::vector<std::unique_ptr<XlaResource>> resources_;
 
-  // Is this a top-level computation, or an inner computation (e.g., a while
-  // body)?
-  const bool is_entry_computation_;
-
   // A function that describes how variable shapes should be represented
   // in XLA. Variable values will be reshaped to this shape. Must be non-null.
   const std::function<TensorShape(const TensorShape&, DataType)>*
-      shape_representation_fn_;
+      variable_representation_shape_fn_;
 
   // Cache of prebuilt computations indexed by their type.
   using ComputationMap = std::map<DataType, xla::XlaComputation>;
index 76c68d8..2b65f4d 100644 (file)
@@ -314,8 +314,8 @@ Status XlaOpKernelContext::ReadVariableInput(int index, DataType type,
   }
 
   XlaContext& xla_context = XlaContext::Get(context_);
-  TensorShape representation_shape =
-      xla_context.RepresentationShape(variable->shape(), variable->type());
+  TensorShape representation_shape = xla_context.VariableRepresentationShape(
+      variable->shape(), variable->type());
   if (representation_shape == variable->shape()) {
     *value = variable->value();
   } else {
@@ -436,7 +436,7 @@ Status XlaOpKernelContext::AssignVariable(int input_index, DataType type,
 
   XlaContext& xla_context = XlaContext::Get(context_);
   TensorShape representation_shape =
-      xla_context.RepresentationShape(shape, type);
+      xla_context.VariableRepresentationShape(shape, type);
   if (shape != representation_shape) {
     handle = builder()->Reshape(handle, representation_shape.dim_sizes());
   }