[TF:XLA] Add a hook to allow reshaping of TensorFlow variables when storing them...
authorPeter Hawkins <phawkins@google.com>
Wed, 14 Feb 2018 22:49:23 +0000 (14:49 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 14 Feb 2018 22:53:40 +0000 (14:53 -0800)
PiperOrigin-RevId: 185748660

tensorflow/compiler/tf2xla/BUILD
tensorflow/compiler/tf2xla/literal_util.cc
tensorflow/compiler/tf2xla/literal_util.h
tensorflow/compiler/tf2xla/xla_compiler.cc
tensorflow/compiler/tf2xla/xla_compiler.h
tensorflow/compiler/tf2xla/xla_compiler_test.cc
tensorflow/compiler/tf2xla/xla_context.cc
tensorflow/compiler/tf2xla/xla_context.h
tensorflow/compiler/tf2xla/xla_op_kernel.cc
tensorflow/compiler/tf2xla/xla_op_kernel.h

index 3c7dfef03dfb5d86dd63fd4aa84ad56081833035..fb82c2601c432cee425a46a3b6dc2c55febeda87 100644 (file)
@@ -312,6 +312,7 @@ tf_cc_test(
         "//tensorflow/cc:cc_ops",
         "//tensorflow/cc:function_ops",
         "//tensorflow/cc:ops",
+        "//tensorflow/cc:resource_variable_ops",
         "//tensorflow/compiler/tf2xla/kernels:xla_ops",
         "//tensorflow/compiler/xla:literal_util",
         "//tensorflow/compiler/xla:shape_util",
index fcbd157c6191655865d5e250fdf71338780bc2a6..2c3cd658e0462368ac0b51938979b7a6815a7574 100644 (file)
@@ -40,20 +40,20 @@ Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal) {
   return Status::OK();
 }
 
-Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type,
-                           Tensor* host_tensor) {
+Status CopyLiteralToHostTensor(const xla::Literal& literal,
+                               Tensor* host_tensor) {
+  TF_RET_CHECK(xla::ShapeUtil::IsArray(literal.shape()) &&
+               xla::ShapeUtil::ElementsIn(literal.shape()) ==
+                   host_tensor->NumElements());
   xla::PrimitiveType primitive_type;
-  TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(target_type, &primitive_type));
+  TF_RETURN_IF_ERROR(
+      DataTypeToPrimitiveType(host_tensor->dtype(), &primitive_type));
   if (literal.shape().element_type() != primitive_type) {
     return errors::InvalidArgument(
         "Cannot convert literal of type ",
         xla::PrimitiveType_Name(literal.shape().element_type()),
-        " to tensor of type ", DataTypeString(target_type));
+        " to tensor of type ", DataTypeString(host_tensor->dtype()));
   }
-
-  TensorShape shape;
-  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(literal.shape(), &shape));
-  *host_tensor = Tensor(target_type, shape);
   size_t total_bytes = host_tensor->TotalBytes();
   if (total_bytes > 0) {
     const void* src_ptr = literal.untyped_data();
@@ -63,4 +63,12 @@ Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type,
   return Status::OK();
 }
 
+Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type,
+                           Tensor* host_tensor) {
+  TensorShape shape;
+  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(literal.shape(), &shape));
+  *host_tensor = Tensor(target_type, shape);
+  return CopyLiteralToHostTensor(literal, host_tensor);
+}
+
 }  // namespace tensorflow
index fe08e83c2391a8b24696961cacfd909d46e49e7d..f283b0236811f8d52e8fe2982a74c11c92cd20d8 100644 (file)
@@ -29,7 +29,8 @@ namespace tensorflow {
 // unsupported type.
 Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal);
 
-// Copies 'literal' to 'host_tensor', which is allocated of type <target_type>.
+// Copies 'literal' to freshly allocated 'host_tensor', which is allocated of
+// type <target_type>.
 // Fails if the literal's primitive type !=
 // DataTypeToPrimitiveType(target_type). Note that <target_type> is not
 // derivable from the type of <literal>, because multiple tensorflow types map
@@ -38,6 +39,12 @@ Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal);
 Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type,
                            Tensor* host_tensor);
 
+// Copies the contents of 'literal' to a previously allocated tensor
+// 'host_tensor'. The tensor and the literal must have the same number of
+// elements and the same type.
+Status CopyLiteralToHostTensor(const xla::Literal& literal,
+                               Tensor* host_tensor);
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_
index 59e88304422eaeaaf3f63cc4d476a8ec7ce95623..15bba46ac62a97592656942afc767a303c9b97f3 100644 (file)
@@ -109,6 +109,12 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options)
 
   local_flib_runtime_ = local_pflr_->GetFLR(device_->name());
   flib_runtime_ = pflr_->GetFLR(device_->name());
+
+  // The default variable representation shape is the identity function.
+  if (!options_.variable_representation_shape_fn) {
+    options_.variable_representation_shape_fn =
+        [](const TensorShape& shape, DataType type) { return shape; };
+  }
 }
 
 XlaCompiler::~XlaCompiler() = default;
@@ -223,8 +229,8 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options,
 }
 
 // Computes the XLA shape for argument 'arg'.
-/*static*/ Status XlaCompiler::XLAShapeForArgument(
-    const XlaCompiler::Argument& arg, xla::Shape* xla_shape) {
+Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
+                                        xla::Shape* xla_shape) {
   switch (arg.kind) {
     case XlaCompiler::Argument::kConstant:
       return TensorShapeToXLAShape(arg.type, arg.constant_value.shape(),
@@ -235,8 +241,12 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options,
       TF_RET_CHECK(arg.initialized);
 
       switch (arg.resource_kind) {
-        case XlaResource::kVariable:
-          return TensorShapeToXLAShape(arg.type, arg.shape, xla_shape);
+        case XlaResource::kVariable: {
+          TensorShape representation_shape =
+              options_.variable_representation_shape_fn(arg.shape, arg.type);
+          return TensorShapeToXLAShape(arg.type, representation_shape,
+                                       xla_shape);
+        }
         case XlaResource::kTensorArray: {
           if (arg.tensor_array_size < 0) {
             return errors::InvalidArgument(
@@ -310,16 +320,116 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
   return Status::OK();
 }
 
+// Builds the XLA computation.
+//
+// `retvals` is the list of retvals produced by _Retval operators, in index
+// order. `variable_map` is a map from variable ID numbers to XlaOpContext
+// variable states, generated by the symbolic evaluation.
+// If `return_updated_values_for_all_resources` is true, all resources will be
+// included in `resource_updates`, regardless of whether their value changed.
+// Sets `*num_nonconst_outputs` to the number of outputs of the `computation`.
+// Sets `*resource_updates` to a description of resources whose values are
+// written by the computation; the variable writes are the last
+// `resource_updates.size()` return values from the computation. Each entry in
+// `resource_updates` is a (input_index, type) pair, where `input_index` is the
+// index of a resource variable argument to the computation, and `type` is the
+// type of the final output.
+Status BuildComputation(
+    const std::vector<XlaCompiler::Argument>& args,
+    const std::vector<int>& arg_cores,
+    const std::vector<XlaExpression>& retvals,
+    const std::vector<std::unique_ptr<XlaResource>>& resources,
+    bool return_updated_values_for_all_resources,
+    xla::ComputationBuilder* builder, xla::Computation* computation,
+    int* num_computation_outputs, int* num_nonconst_outputs,
+    std::vector<XlaCompiler::ResourceUpdate>* resource_updates) {
+  std::vector<xla::ComputationDataHandle> elems;
+  elems.reserve(retvals.size());
+  for (const XlaExpression& retval : retvals) {
+    if (!retval.has_constant_value()) {
+      elems.push_back(retval.handle());
+    }
+  }
+  *num_nonconst_outputs = elems.size();
+
+  // Add return values for resources whose values have changed.
+  std::vector<const XlaResource*> arg_resources;
+  arg_resources.reserve(resources.size());
+  for (const auto& resource : resources) {
+    if (resource->arg_num() >= 0) {
+      arg_resources.push_back(resource.get());
+    }
+  }
+  std::sort(arg_resources.begin(), arg_resources.end(),
+            [](const XlaResource* a, const XlaResource* b) {
+              return a->arg_num() < b->arg_num();
+            });
+
+  for (const XlaResource* resource : arg_resources) {
+    const XlaCompiler::Argument& arg = args[resource->arg_num()];
+    const int core = arg_cores[resource->arg_num()];
+    DCHECK_LT(resource->arg_num(), arg_cores.size());
+    bool modified =
+        resource->value().handle() != resource->initial_value().handle();
+    // TensorArray gradients were modified if their values changed or there are
+    // any newly created gradients.
+    for (const auto& grad : resource->tensor_array_gradients()) {
+      modified = modified ||
+                 grad.second->value().handle() !=
+                     grad.second->initial_value().handle() ||
+                 arg.tensor_array_gradients.count(grad.first) == 0;
+    }
+    if (return_updated_values_for_all_resources || modified) {
+      resource_updates->emplace_back();
+      XlaCompiler::ResourceUpdate& update = resource_updates->back();
+      update.input_index = resource->arg_num();
+      update.type = resource->type();
+      update.shape = resource->shape();
+      update.modified = modified;
+      for (const auto& grad : resource->tensor_array_gradients()) {
+        update.tensor_array_gradients_accessed.insert(grad.first);
+      }
+
+      // Request that the value be returned on a specific core.
+      xla::ScopedShardingAssignment assign_sharding(
+          builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>()
+                              : xla::sharding_builder::AssignDevice(core));
+
+      xla::ComputationDataHandle handle;
+      TF_RETURN_IF_ERROR(resource->Pack(&handle, builder));
+
+      // Since we can't change the sharding metadata of <value> as this point,
+      // create a tuple/get-tuple-element combination so that sharding
+      // assignment will be placed on this value, which will cause the resource
+      // update to be returned from the same device that provided the resource.
+      handle = builder->GetTupleElement(builder->Tuple({handle}), 0);
+
+      elems.push_back(handle);
+    }
+  }
+
+  *num_computation_outputs = elems.size();
+
+  // Builds the XLA computation.
+  builder->Tuple(elems);
+  xla::StatusOr<xla::Computation> computation_status = builder->Build();
+  if (!computation_status.ok()) {
+    return computation_status.status();
+  }
+  *computation = computation_status.ConsumeValueOrDie();
+  return Status::OK();
+}
+
+}  // namespace
+
 // Builds XLA computations for each of the arguments to the computation.
 // `args` are the arguments to the computation.
-Status BuildArguments(const Graph& graph,
-                      const std::vector<XlaCompiler::Argument>& args,
-                      bool use_tuple_arg, xla::ComputationBuilder* builder,
-                      XlaContext* context, std::vector<int>* arg_cores,
-                      std::vector<XlaExpression>* arg_expressions,
-                      std::vector<int>* input_mapping,
-                      std::vector<xla::Shape>* input_shapes,
-                      bool is_entry_computation) {
+Status XlaCompiler::BuildArguments(
+    const Graph& graph, const std::vector<XlaCompiler::Argument>& args,
+    bool use_tuple_arg, xla::ComputationBuilder* builder, XlaContext* context,
+    std::vector<int>* arg_cores, std::vector<XlaExpression>* arg_expressions,
+    std::vector<int>* input_mapping, std::vector<xla::Shape>* input_shapes,
+    bool is_entry_computation) {
   arg_expressions->resize(args.size());
   *arg_cores = std::vector<int>(args.size(), -1);
 
@@ -374,8 +484,8 @@ Status BuildArguments(const Graph& graph,
   std::vector<xla::Shape> arg_shapes(input_mapping->size());
   for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
     // Computes the shapes of non-constant arguments.
-    TF_RETURN_IF_ERROR(XlaCompiler::XLAShapeForArgument(
-        args[(*input_mapping)[i]], &arg_shapes[i]));
+    TF_RETURN_IF_ERROR(
+        XLAShapeForArgument(args[(*input_mapping)[i]], &arg_shapes[i]));
   }
 
   if (use_tuple_arg) {
@@ -472,108 +582,6 @@ Status BuildArguments(const Graph& graph,
   return Status::OK();
 }
 
-// Builds the XLA computation.
-//
-// `retvals` is the list of retvals produced by _Retval operators, in index
-// order. `variable_map` is a map from variable ID numbers to XlaOpContext
-// variable states, generated by the symbolic evaluation.
-// If `return_updated_values_for_all_resources` is true, all resources will be
-// included in `resource_updates`, regardless of whether their value changed.
-// Sets `*num_nonconst_outputs` to the number of outputs of the `computation`.
-// Sets `*resource_updates` to a description of resources whose values are
-// written by the computation; the variable writes are the last
-// `resource_updates.size()` return values from the computation. Each entry in
-// `resource_updates` is a (input_index, type) pair, where `input_index` is the
-// index of a resource variable argument to the computation, and `type` is the
-// type of the final output.
-Status BuildComputation(
-    const std::vector<XlaCompiler::Argument>& args,
-    const std::vector<int>& arg_cores,
-    const std::vector<XlaExpression>& retvals,
-    const std::vector<std::unique_ptr<XlaResource>>& resources,
-    bool return_updated_values_for_all_resources,
-    xla::ComputationBuilder* builder, xla::Computation* computation,
-    int* num_computation_outputs, int* num_nonconst_outputs,
-    std::vector<XlaCompiler::ResourceUpdate>* resource_updates) {
-  std::vector<xla::ComputationDataHandle> elems;
-  elems.reserve(retvals.size());
-  for (const XlaExpression& retval : retvals) {
-    if (!retval.has_constant_value()) {
-      elems.push_back(retval.handle());
-    }
-  }
-  *num_nonconst_outputs = elems.size();
-
-  // Add return values for resources whose values have changed.
-  std::vector<const XlaResource*> arg_resources;
-  arg_resources.reserve(resources.size());
-  for (const auto& resource : resources) {
-    if (resource->arg_num() >= 0) {
-      arg_resources.push_back(resource.get());
-    }
-  }
-  std::sort(arg_resources.begin(), arg_resources.end(),
-            [](const XlaResource* a, const XlaResource* b) {
-              return a->arg_num() < b->arg_num();
-            });
-
-  for (const XlaResource* resource : arg_resources) {
-    const XlaCompiler::Argument& arg = args[resource->arg_num()];
-    const int core = arg_cores[resource->arg_num()];
-    DCHECK_LT(resource->arg_num(), arg_cores.size());
-    bool modified =
-        resource->value().handle() != resource->initial_value().handle();
-    // TensorArray gradients were modified if their values changed or there are
-    // any newly created gradients.
-    for (const auto& grad : resource->tensor_array_gradients()) {
-      modified = modified ||
-                 grad.second->value().handle() !=
-                     grad.second->initial_value().handle() ||
-                 arg.tensor_array_gradients.count(grad.first) == 0;
-    }
-    if (return_updated_values_for_all_resources || modified) {
-      resource_updates->emplace_back();
-      XlaCompiler::ResourceUpdate& update = resource_updates->back();
-      update.input_index = resource->arg_num();
-      update.type = resource->type();
-      update.shape = resource->shape();
-      update.modified = modified;
-      for (const auto& grad : resource->tensor_array_gradients()) {
-        update.tensor_array_gradients_accessed.insert(grad.first);
-      }
-
-      // Request that the value be returned on a specific core.
-      xla::ScopedShardingAssignment assign_sharding(
-          builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>()
-                              : xla::sharding_builder::AssignDevice(core));
-
-      xla::ComputationDataHandle handle;
-      TF_RETURN_IF_ERROR(resource->Pack(&handle, builder));
-
-      // Since we can't change the sharding metadata of <value> as this point,
-      // create a tuple/get-tuple-element combination so that sharding
-      // assignment will be placed on this value, which will cause the resource
-      // update to be returned from the same device that provided the resource.
-      handle = builder->GetTupleElement(builder->Tuple({handle}), 0);
-
-      elems.push_back(handle);
-    }
-  }
-
-  *num_computation_outputs = elems.size();
-
-  // Builds the XLA computation.
-  builder->Tuple(elems);
-  xla::StatusOr<xla::Computation> computation_status = builder->Build();
-  if (!computation_status.ok()) {
-    return computation_status.status();
-  }
-  *computation = computation_status.ConsumeValueOrDie();
-  return Status::OK();
-}
-
-}  // namespace
-
 Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
                                  string const& name,
                                  std::unique_ptr<Graph> graph,
@@ -598,7 +606,8 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
   xla::ComputationBuilder builder(client(), name);
   XlaContext* context =
       new XlaContext(this, &builder, options_.allow_cpu_custom_calls,
-                     options.resolve_compile_time_constants);
+                     options.resolve_compile_time_constants,
+                     &options_.variable_representation_shape_fn);
   core::ScopedUnref context_unref(context);
 
   std::vector<XlaExpression> arg_expressions;
index b86c82c0ab5ce379d35a13043857f459199e2ad2..c4449bc4be06daff856eff70c6d89be6ddbcf0ee 100644 (file)
@@ -29,6 +29,9 @@ limitations under the License.
 #include "tensorflow/core/public/version.h"
 
 namespace tensorflow {
+
+class XlaContext;
+
 // The XlaCompiler class is responsible for compilation of a self-contained
 // subgraph of a TensorFlow computation using the XLA linear algebra runtime.
 // It does a symbolic execution of the graph starting from specific input
@@ -239,6 +242,12 @@ class XlaCompiler {
     // for CPU.
     bool allow_cpu_custom_calls = false;
 
+    // If set, the XLA representation of variables represented to XLA as the
+    // shape given by this shape function. Variables are reshaped to this shape
+    // on write, and reshaped to their original shape on read.
+    std::function<TensorShape(const TensorShape&, DataType)>
+        variable_representation_shape_fn;
+
     // If not nullptr, populate_resource_manager is called with the
     // compilation device's resource manager when the compilation
     // device is created, and can be used to create metadata objects
@@ -278,7 +287,7 @@ class XlaCompiler {
   // Returns the shape of the XLA parameter for an argument 'arg'.
   // See the class comment for more details about the argument passing
   // convention.
-  static Status XLAShapeForArgument(const Argument& arg, xla::Shape* xla_shape);
+  Status XLAShapeForArgument(const Argument& arg, xla::Shape* xla_shape);
 
   // Retrieves the channel handle associated with `key`. Allocates
   // a new channel handle if none exists.
@@ -299,6 +308,17 @@ class XlaCompiler {
   // Returns the optimized graph object in this function body.
   std::unique_ptr<Graph> GetGraph(const FunctionBody* fbody);
 
+  // Builds XLA computations for each of the arguments to the computation.
+  // `args` are the arguments to the computation.
+  Status BuildArguments(const Graph& graph,
+                        const std::vector<XlaCompiler::Argument>& args,
+                        bool use_tuple_arg, xla::ComputationBuilder* builder,
+                        XlaContext* context, std::vector<int>* arg_cores,
+                        std::vector<XlaExpression>* arg_expressions,
+                        std::vector<int>* input_mapping,
+                        std::vector<xla::Shape>* input_shapes,
+                        bool is_entry_computation);
+
   // Graph compiler needs to know how to get an optimized graph from a function
   // body.
   friend class GraphCompiler;
index 65de4dbad75b7fb76a041bc799fc31dc5cb80d74..a18eeacd41808884fac9ec5d617cb0d274ea27d8 100644 (file)
@@ -17,6 +17,7 @@ limitations under the License.
 #include "tensorflow/cc/framework/ops.h"
 #include "tensorflow/cc/ops/data_flow_ops.h"
 #include "tensorflow/cc/ops/function_ops.h"
+#include "tensorflow/cc/ops/resource_variable_ops.h"
 #include "tensorflow/cc/ops/standard_ops.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -683,5 +684,128 @@ TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) {
       << status.error_message();
 }
 
+// Tests a simple graph that reads and writes a variable.
+TEST_F(XlaCompilerTest, Variables) {
+  Scope scope = Scope::NewRootScope().ExitOnError();
+  auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
+  auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
+  auto write = ops::AssignAddVariableOp(scope, var, a);
+  auto read = ops::ReadVariableOp(
+      scope.WithControlDependencies(std::vector<Operation>{write}), var,
+      DT_INT32);
+  auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1));
+  auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0);
+  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+  TF_ASSERT_OK(scope.ToGraph(graph.get()));
+
+  // Builds a description of the arguments.
+  std::vector<XlaCompiler::Argument> args(2);
+  args[0].kind = XlaCompiler::Argument::kParameter;
+  args[0].type = DT_INT32;
+  args[0].shape = TensorShape({2});
+  args[1].kind = XlaCompiler::Argument::kResource;
+  args[1].resource_kind = XlaResource::kVariable;
+  args[1].initialized = true;
+  args[1].type = DT_INT32;
+  args[1].shape = TensorShape({2});
+
+  // Compiles the graph.
+  XlaCompiler compiler(DefaultOptions());
+
+  XlaCompiler::CompilationResult result;
+  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
+                                     std::move(graph), args, &result));
+
+  // Tests that the generated computation works.
+  std::unique_ptr<xla::Literal> param0_literal =
+      xla::Literal::CreateR1<int32>({7, 42});
+  std::unique_ptr<xla::Literal> param1_literal =
+      xla::Literal::CreateR1<int32>({-3, 101});
+  std::unique_ptr<xla::GlobalData> param0_data =
+      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+  std::unique_ptr<xla::GlobalData> param1_data =
+      client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+
+  std::unique_ptr<xla::GlobalData> actual =
+      client_
+          ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
+          .ConsumeValueOrDie();
+  std::unique_ptr<xla::Literal> actual_literal =
+      client_->Transfer(*actual).ConsumeValueOrDie();
+
+  std::unique_ptr<xla::Literal> expected0 =
+      xla::Literal::CreateR1<int32>({5, 144});
+  std::unique_ptr<xla::Literal> expected1 =
+      xla::Literal::CreateR1<int32>({4, 143});
+  std::unique_ptr<xla::Literal> expected_literal =
+      xla::Literal::MakeTuple({expected0.get(), expected1.get()});
+  xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
+}
+
+// Tests a simple graph that reads and writes a variable, with a
+// variable_representation_shape_fn passed to the compiler that flattens all
+// variable tensors to vectors.
+TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
+  Scope scope = Scope::NewRootScope().ExitOnError();
+  auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
+  auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
+  auto write = ops::AssignAddVariableOp(scope, var, a);
+  auto read = ops::ReadVariableOp(
+      scope.WithControlDependencies(std::vector<Operation>{write}), var,
+      DT_INT32);
+  auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1));
+  auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0);
+  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+  TF_ASSERT_OK(scope.ToGraph(graph.get()));
+
+  // Builds a description of the arguments.
+  std::vector<XlaCompiler::Argument> args(2);
+  args[0].kind = XlaCompiler::Argument::kParameter;
+  args[0].type = DT_INT32;
+  args[0].shape = TensorShape({2, 2});
+  args[1].kind = XlaCompiler::Argument::kResource;
+  args[1].resource_kind = XlaResource::kVariable;
+  args[1].initialized = true;
+  args[1].type = DT_INT32;
+  args[1].shape = TensorShape({2, 2});
+
+  // Compiles the graph.
+  XlaCompiler::Options options = DefaultOptions();
+  options.variable_representation_shape_fn = [](const TensorShape& shape,
+                                                DataType type) {
+    return TensorShape({shape.num_elements()});
+  };
+  XlaCompiler compiler(options);
+
+  XlaCompiler::CompilationResult result;
+  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
+                                     std::move(graph), args, &result));
+
+  // Tests that the generated computation works.
+  std::unique_ptr<xla::Literal> param0_literal =
+      xla::Literal::CreateR2<int32>({{4, 55}, {1, -3}});
+  std::unique_ptr<xla::Literal> param1_literal =
+      xla::Literal::CreateR1<int32>({22, 11, 33, 404});
+  std::unique_ptr<xla::GlobalData> param0_data =
+      client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+  std::unique_ptr<xla::GlobalData> param1_data =
+      client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+
+  std::unique_ptr<xla::GlobalData> actual =
+      client_
+          ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
+          .ConsumeValueOrDie();
+  std::unique_ptr<xla::Literal> actual_literal =
+      client_->Transfer(*actual).ConsumeValueOrDie();
+
+  std::unique_ptr<xla::Literal> expected0 =
+      xla::Literal::CreateR2<int32>({{27, 67}, {35, 402}});
+  std::unique_ptr<xla::Literal> expected1 =
+      xla::Literal::CreateR1<int32>({26, 66, 34, 401});
+  std::unique_ptr<xla::Literal> expected_literal =
+      xla::Literal::MakeTuple({expected0.get(), expected1.get()});
+  xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
+}
+
 }  // namespace
 }  // namespace tensorflow
index 73878955e3fd54c103c0b07faf7f5ee5bcd84de0..8423921086fec1cf534cf613102fc3839035cb85 100644 (file)
@@ -62,13 +62,16 @@ void XlaContext::set_args(std::vector<XlaExpression> args) {
   args_ = std::move(args);
 }
 
-XlaContext::XlaContext(XlaCompiler* compiler, xla::ComputationBuilder* builder,
-                       bool allow_cpu_custom_calls,
-                       bool resolve_compile_time_constants)
+XlaContext::XlaContext(
+    XlaCompiler* compiler, xla::ComputationBuilder* builder,
+    bool allow_cpu_custom_calls, bool resolve_compile_time_constants,
+    const std::function<TensorShape(const TensorShape&, DataType)>*
+        variable_representation_shape_fn)
     : compiler_(compiler),
       builder_(builder),
       allow_cpu_custom_calls_(allow_cpu_custom_calls),
-      resolve_compile_time_constants_(resolve_compile_time_constants) {}
+      resolve_compile_time_constants_(resolve_compile_time_constants),
+      variable_representation_shape_fn_(variable_representation_shape_fn) {}
 
 string XlaContext::DebugString() { return "TLA JIT context"; }
 
@@ -115,6 +118,11 @@ Status XlaContext::CreateResource(
   return Status::OK();
 }
 
+TensorShape XlaContext::VariableRepresentationShape(const TensorShape& shape,
+                                                    DataType type) const {
+  return (*variable_representation_shape_fn_)(shape, type);
+}
+
 const xla::Computation* XlaContext::GetOrCreateMax(const DataType type) {
   return LookupOrCreate(type, &max_func_, [this, type] {
     const string type_string = DataTypeString(type);
index fac0352ae81e24597e1045981ac47a7cd09481da..00fbaba37c542954f690b310a184cff985a05156 100644 (file)
@@ -44,7 +44,9 @@ class XlaContext : public ResourceBase {
 
   // Creates a new XlaContext.
   XlaContext(XlaCompiler* compiler, xla::ComputationBuilder* builder,
-             bool allow_cpu_custom_calls, bool resolve_compile_time_constants);
+             bool allow_cpu_custom_calls, bool resolve_compile_time_constants,
+             const std::function<TensorShape(const TensorShape&, DataType)>*
+                 variable_representation_shape_fn);
 
   // Virtual method defined by ResourceBase.
   string DebugString() override;
@@ -86,6 +88,11 @@ class XlaContext : public ResourceBase {
     return resources_;
   }
 
+  // Returns the XLA shape to be used to represent a variable of TF `shape`
+  // and `type`.
+  TensorShape VariableRepresentationShape(const TensorShape& shape,
+                                          DataType type) const;
+
   // Get an XLA lambda to compute Max. This is cached in the
   // XlaContext since it may be used by multiple Ops. There is a
   // separate specialization of the computation for each DataType.
@@ -133,6 +140,11 @@ class XlaContext : public ResourceBase {
   // Holds ownership of resources. The resources are not ordered.
   std::vector<std::unique_ptr<XlaResource>> resources_;
 
+  // A function that describes how variable shapes should be represented
+  // in XLA. Variable values will be reshaped to this shape. Must be non-null.
+  const std::function<TensorShape(const TensorShape&, DataType)>*
+      variable_representation_shape_fn_;
+
   // Cache of prebuilt computations indexed by their type.
   using ComputationMap = std::map<DataType, xla::Computation>;
 
index ee29158646fa96fe554d089e11d50afb47e3e300..c4bb90d58755f16672ca7c6a6738065be6330485 100644 (file)
@@ -302,10 +302,19 @@ Status XlaOpKernelContext::ReadVariableInput(
         "Type mismatch for read of variable ", variable->name(), ". Expected ",
         DataTypeString(type), "; got ", DataTypeString(variable->type()));
   }
-  *value = variable->value();
   if (shape) {
     *shape = variable->shape();
   }
+
+  XlaContext& xla_context = XlaContext::Get(context_);
+  TensorShape representation_shape = xla_context.VariableRepresentationShape(
+      variable->shape(), variable->type());
+  if (representation_shape == variable->shape()) {
+    *value = variable->value();
+  } else {
+    *value =
+        builder()->Reshape(variable->value(), variable->shape().dim_sizes());
+  }
   return Status::OK();
 }
 
@@ -400,8 +409,8 @@ Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) {
   return Status::OK();
 }
 
-Status XlaOpKernelContext::AssignVariable(
-    int input_index, DataType type, const xla::ComputationDataHandle& handle) {
+Status XlaOpKernelContext::AssignVariable(int input_index, DataType type,
+                                          xla::ComputationDataHandle handle) {
   TF_RET_CHECK(handle.handle() != 0);
 
   const XlaExpression* expression =
@@ -419,6 +428,13 @@ Status XlaOpKernelContext::AssignVariable(
       XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), &shape));
 
   TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape));
+
+  XlaContext& xla_context = XlaContext::Get(context_);
+  TensorShape representation_shape =
+      xla_context.VariableRepresentationShape(shape, type);
+  if (shape != representation_shape) {
+    handle = builder()->Reshape(handle, representation_shape.dim_sizes());
+  }
   return variable->SetValue(handle);
 }
 
index e1fd0f55c6d2501b4813c90171630a8df567f78a..4e4b97e0cec8d16b9b5686a779b1285906765dbd 100644 (file)
@@ -175,7 +175,7 @@ class XlaOpKernelContext {
   // variable has been initialized with a different type or with a
   // different shape.
   Status AssignVariable(int input_index, DataType type,
-                        const xla::ComputationDataHandle& handle);
+                        xla::ComputationDataHandle handle);
 
   // Helper routines for the OP_REQUIRES macros
   void CtxFailure(const Status& s);