[TF:XLA] Improve/refactor the handling of resource types/shapes.
authorPeter Hawkins <phawkins@google.com>
Fri, 2 Feb 2018 19:31:06 +0000 (11:31 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 2 Feb 2018 19:36:32 +0000 (11:36 -0800)
Previously we used an xla::Shape to track the shape of a resource (Variable, TensorArray, Stack) shape. The xla::Shape described how the resource was represented to XLA, e.g., as a (buffer, size) pair for a Stack resource.

Instead, separate the TensorFlow abstract shape representation from the XLA shape representation and track it separately. This leads to simpler and more readable code.

PiperOrigin-RevId: 184310694

19 files changed:
tensorflow/compiler/jit/kernels/xla_launch_op.cc
tensorflow/compiler/jit/xla_compilation_cache.cc
tensorflow/compiler/tf2xla/graph_compiler.cc
tensorflow/compiler/tf2xla/kernels/stack_ops.cc
tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
tensorflow/compiler/tf2xla/kernels/training_ops.cc
tensorflow/compiler/tf2xla/kernels/variable_ops.cc
tensorflow/compiler/tf2xla/kernels/while_op.cc
tensorflow/compiler/tf2xla/tf2xla.cc
tensorflow/compiler/tf2xla/xla_compiler.cc
tensorflow/compiler/tf2xla/xla_compiler.h
tensorflow/compiler/tf2xla/xla_compiler_test.cc
tensorflow/compiler/tf2xla/xla_context.cc
tensorflow/compiler/tf2xla/xla_context.h
tensorflow/compiler/tf2xla/xla_op_kernel.cc
tensorflow/compiler/tf2xla/xla_op_kernel.h
tensorflow/compiler/tf2xla/xla_resource.cc
tensorflow/compiler/tf2xla/xla_resource.h

index 17ae2bb..6353149 100644 (file)
@@ -376,8 +376,6 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
     OP_REQUIRES(ctx,
                 write.input_index >= 0 && write.input_index < ctx->num_inputs(),
                 errors::Internal("Invalid input index for variable write."));
-    TensorShape write_shape;
-    OP_REQUIRES_OK(ctx, XLAShapeToTensorShape(write.shape, &write_shape));
 
     gpu::DeviceMemoryBase buffer = output->buffer({output_num});
 
@@ -399,7 +397,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
 
     // Looks up the owning Tensor by buffer address.
     OP_REQUIRES_OK(
-        ctx, xla_allocator.MakeTensorFromBuffer(buffer, write.type, write_shape,
+        ctx, xla_allocator.MakeTensorFromBuffer(buffer, write.type, write.shape,
                                                 variable->tensor()));
     ++output_num;
   }
index 21d3a54..6d854a9 100644 (file)
@@ -148,8 +148,7 @@ Status BuildArguments(int num_constant_args,
     XlaCompiler::Argument& arg = (*args)[input_num];
     arg.kind = XlaCompiler::Argument::kConstant;
     arg.type = input.dtype();
-    TF_RETURN_IF_ERROR(
-        TensorShapeToXLAShape(input.dtype(), input.shape(), &arg.shape));
+    arg.shape = input.shape();
     arg.constant_value = input;
     ++input_num;
   }
@@ -170,8 +169,7 @@ Status BuildArguments(int num_constant_args,
       arg.constant_value = input;
     }
     arg.type = input.dtype();
-    TF_RETURN_IF_ERROR(
-        TensorShapeToXLAShape(input.dtype(), input.shape(), &arg.shape));
+    arg.shape = input.shape();
     ++input_num;
   }
 
@@ -189,8 +187,7 @@ Status BuildArguments(int num_constant_args,
     if (variable_args[variable_id].present) {
       const Tensor& value = variable_args[variable_id].value;
       arg.type = value.dtype();
-      TF_RETURN_IF_ERROR(
-          TensorShapeToXLAShape(value.dtype(), value.shape(), &arg.shape));
+      arg.shape = value.shape();
       arg.initialized = true;
     } else {
       // The values of uninitialized variables are not passed as inputs, since
@@ -199,7 +196,7 @@ Status BuildArguments(int num_constant_args,
       // uninitialized variables.
       arg.initialized = false;
       arg.type = DT_INVALID;
-      arg.shape = xla::Shape();
+      arg.shape = TensorShape();
     }
     ++input_num;
   }
index 02215b5..1418d95 100644 (file)
@@ -60,9 +60,7 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
   for (int i = 0; i < args->size(); ++i) {
     XlaCompiler::Argument& arg = (*args)[i];
     arg.type = ctx->input_type(i);
-
-    TF_RETURN_IF_ERROR(
-        TensorShapeToXLAShape(arg.type, ctx->InputShape(i), &arg.shape));
+    arg.shape = ctx->InputShape(i);
 
     if (arg.type == DT_RESOURCE) {
       return errors::InvalidArgument(
index d77fb76..1a78c7a 100644 (file)
@@ -77,10 +77,8 @@ Status MaybeInitializeStack(xla::ComputationBuilder* builder,
     // Stack has not been initialized.
     xla::ComputationDataHandle zero =
         XlaHelpers::Zero(builder, resource->type());
-    TF_RETURN_IF_ERROR(resource->SetValue(
-        dtype,
-        builder->Tuple({builder->Broadcast(zero, stack_shape.dim_sizes()),
-                        builder->ConstantR0<int32>(0)})));
+    TF_RETURN_IF_ERROR(resource->SetTypeAndShape(dtype, elem_shape));
+    TF_RETURN_IF_ERROR(resource->SetZeroValue(builder));
   } else {
     // Checks the expected shape matches the actual shape.
     TensorShape actual_shape;
@@ -119,8 +117,8 @@ class StackOp : public XlaOpKernel {
     string name = strings::StrCat("Stack: ", stack_name_);
     OP_REQUIRES_OK(
         ctx, xc.CreateResource(XlaResource::kStack, -1, std::move(name), dtype_,
-                               value, &resource));
-    resource->set_tensor_array_size(size);
+                               TensorShape(), value, /*tensor_array_size=*/size,
+                               /*tensor_array_gradients=*/{}, &resource));
     ctx->SetResourceOutput(0, resource);
   }
 
@@ -164,11 +162,9 @@ class StackPushOp : public XlaOpKernel {
 
     // TODO(phawkins): We don't check the index is in bounds --- there is no
     // error mechanism in XLA.
-    OP_REQUIRES_OK(
-        ctx,
-        resource->SetValue(
-            dtype_, b->Tuple({b->DynamicUpdateSlice(ta, update, start_indices),
-                              b->Add(index, b->ConstantR0<int32>(1))})));
+    OP_REQUIRES_OK(ctx, resource->SetValue(b->Tuple(
+                            {b->DynamicUpdateSlice(ta, update, start_indices),
+                             b->Add(index, b->ConstantR0<int32>(1))})));
 
     ctx->SetOutput(0, value);
   }
@@ -208,7 +204,7 @@ class StackPopOp : public XlaOpKernel {
     xla::ComputationDataHandle index = b->GetTupleElement(state, 1);
 
     index = b->Sub(index, b->ConstantR0<int32>(1));
-    OP_REQUIRES_OK(ctx, resource->SetValue(dtype_, b->Tuple({ta, index})));
+    OP_REQUIRES_OK(ctx, resource->SetValue(b->Tuple({ta, index})));
 
     // start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
     auto start_indices =
index f0525a5..91c1694 100644 (file)
@@ -231,6 +231,7 @@ class StridedSliceAssignOp : public XlaOpKernel {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("new_axis_mask", &new_axis_mask_));
     OP_REQUIRES_OK(ctx, ctx->GetAttr("shrink_axis_mask", &shrink_axis_mask_));
     OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
   }
 
   void Compile(XlaOpKernelContext* ctx) override {
@@ -252,9 +253,9 @@ class StridedSliceAssignOp : public XlaOpKernel {
     OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_,
                                             &strides_tensor));
 
-    DataType lhs_type;
     TensorShape lhs_shape;
-    OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &lhs_type, &lhs_shape));
+    xla::ComputationDataHandle lhs;
+    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &lhs_shape, &lhs));
 
     const TensorShape rhs_shape = ctx->InputShape(4);
 
@@ -282,9 +283,6 @@ class StridedSliceAssignOp : public XlaOpKernel {
                     " does not match r-value shape ", rhs_shape.DebugString(),
                     ". Automatic broadcasting not yet implemented."));
 
-    xla::ComputationDataHandle lhs;
-    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &lhs));
-
     xla::ComputationDataHandle rhs = ctx->Input(4);
 
     gtl::InlinedVector<int64, 4> dimensions_to_reverse;
@@ -320,13 +318,14 @@ class StridedSliceAssignOp : public XlaOpKernel {
           lhs, rhs, ctx->builder()->ConstantR1<int64>(slice_begin));
     }
 
-    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, lhs_type, lhs));
+    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, lhs));
   }
 
  private:
   int32 begin_mask_, end_mask_;
   int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_;
   DataType index_type_;
+  DataType dtype_;
 };
 
 REGISTER_XLA_OP(Name("ResourceStridedSliceAssign")
index 9224072..7cf9b79 100644 (file)
@@ -62,15 +62,13 @@ Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder,
 
   TF_RET_CHECK(resource->tensor_array_size() >= 0)
       << resource->name() << " size " << resource->tensor_array_size();
-  TensorShape ta_shape;
-  ta_shape.AddDim(resource->tensor_array_size());
-  ta_shape.AppendShape(elem_shape);
 
   if (!resource->initialized()) {
     xla::ComputationDataHandle zero =
         XlaHelpers::Zero(builder, resource->type());
-    TF_RETURN_IF_ERROR(resource->SetValue(
-        dtype, builder->Broadcast(zero, ta_shape.dim_sizes())));
+
+    TF_RETURN_IF_ERROR(resource->SetTypeAndShape(dtype, elem_shape));
+    TF_RETURN_IF_ERROR(resource->SetZeroValue(builder));
   } else {
     // Checks the elem_shape matches the TensorArray shape.
     auto shape_or_status = builder->GetShape(resource->value());
@@ -80,6 +78,10 @@ Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder,
     TensorShape shape;
     TF_RETURN_IF_ERROR(
         XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), &shape));
+
+    TensorShape ta_shape;
+    ta_shape.AddDim(resource->tensor_array_size());
+    ta_shape.AppendShape(elem_shape);
     if (ta_shape != shape) {
       return errors::InvalidArgument(
           "Mismatched TensorArray sizes: ", ta_shape.DebugString(), " vs ",
@@ -114,10 +116,8 @@ Status CheckTensorArrayIsInitialized(const string& op_name,
 Status GetTensorArrayShape(const XlaResource* resource,
                            xla::ComputationBuilder* builder,
                            TensorShape* shape) {
-  TF_RETURN_IF_ERROR(resource->GetShape(builder, shape));
-  if (shape->dims() < 1) {
-    return errors::InvalidArgument("TensorArray rank must be >= 1");
-  }
+  *shape = resource->shape();
+  shape->InsertDim(0, resource->tensor_array_size());
   return Status::OK();
 }
 
@@ -160,8 +160,8 @@ class TensorArrayOp : public XlaOpKernel {
     // Initializes the TensorArray value if we know the element shape.
     // Otherwise, defer initialization to the first write.
     xla::ComputationDataHandle value;
+    TensorShape shape;
     if (element_shape_.IsFullyDefined()) {
-      TensorShape shape;
       CHECK(element_shape_.AsTensorShape(&shape));
       TensorShape ta_shape;
       ta_shape.AddDim(size);
@@ -175,8 +175,8 @@ class TensorArrayOp : public XlaOpKernel {
     string name = strings::StrCat("TensorArray: ", tensor_array_name_);
     OP_REQUIRES_OK(
         ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name),
-                               dtype_, value, &var));
-    var->set_tensor_array_size(size);
+                               dtype_, shape, value, /*tensor_array_size=*/size,
+                               /*tensor_array_gradients=*/{}, &var));
     ctx->SetResourceOutput(0, var);
 
     Tensor flow(DT_FLOAT, TensorShape({}));
@@ -230,7 +230,7 @@ class TensorArrayWriteOp : public XlaOpKernel {
     xla::ComputationDataHandle written =
         DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices);
 
-    OP_REQUIRES_OK(ctx, resource->SetValue(dtype_, written));
+    OP_REQUIRES_OK(ctx, resource->SetValue(written));
     ctx->SetOutput(0, flow);
   }
 
@@ -421,7 +421,7 @@ class TensorArrayScatterOp : public XlaOpKernel {
       }
     }
 
-    OP_REQUIRES_OK(ctx, resource->SetValue(dtype_, ta));
+    OP_REQUIRES_OK(ctx, resource->SetValue(ta));
     ctx->SetOutput(0, flow);
   }
 
@@ -525,9 +525,8 @@ class TensorArraySplitOp : public XlaOpKernel {
                                         value_shape.DebugString(), " vs. ",
                                         ta_shape.DebugString()));
 
-    OP_REQUIRES_OK(
-        ctx, resource->SetValue(
-                 dtype_, b->Add(ta, b->Reshape(value, ta_shape.dim_sizes()))));
+    OP_REQUIRES_OK(ctx, resource->SetValue(b->Add(
+                            ta, b->Reshape(value, ta_shape.dim_sizes()))));
 
     ctx->SetOutput(0, flow);
   }
index 5534d1b..f750f70 100644 (file)
@@ -32,9 +32,24 @@ class ResourceApplyGradientDescent : public XlaOpKernel {
   void Compile(XlaOpKernelContext* ctx) override {
     xla::ComputationDataHandle handle;
     xla::ComputationBuilder* b = ctx->builder();
-    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle));
+    DataType type = ctx->input_type(1);
+    TensorShape var_shape;
+    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &handle));
+
+    TensorShape alpha_shape = ctx->InputShape(1);
+    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape),
+                errors::InvalidArgument("alpha is not a scalar: ",
+                                        alpha_shape.DebugString()));
+
+    TensorShape delta_shape = ctx->InputShape(2);
+    OP_REQUIRES(
+        ctx, var_shape.IsSameSize(delta_shape),
+        errors::InvalidArgument("var and delta do not have the same shape: ",
+                                var_shape.DebugString(), " vs ",
+                                delta_shape.DebugString()));
+
     handle = b->Sub(handle, b->Mul(ctx->Input(1), ctx->Input(2)));
-    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, ctx->input_type(1), handle));
+    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle));
   }
 };
 REGISTER_XLA_OP(
@@ -52,18 +67,10 @@ class ResourceApplyMomentum : public XlaOpKernel {
 
     DataType type = ctx->input_type(2);
 
-    DataType var_type, accum_type;
     TensorShape var_shape, accum_shape;
-    OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape));
-    OP_REQUIRES_OK(ctx,
-                   ctx->GetVariableTypeAndShape(1, &accum_type, &accum_shape));
-
-    OP_REQUIRES(
-        ctx, type == var_type && type == accum_type,
-        errors::InvalidArgument(
-            "Types of variable arguments to ResourceApplyMomentum must match: ",
-            DataTypeString(type), " vs. ", DataTypeString(var_type), " and ",
-            DataTypeString(accum_type)));
+    xla::ComputationDataHandle var, accum;
+    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
+    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum));
 
     OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
                 errors::InvalidArgument(
@@ -86,10 +93,6 @@ class ResourceApplyMomentum : public XlaOpKernel {
                 errors::InvalidArgument("momentum is not a scalar: ",
                                         momentum_shape.DebugString()));
 
-    xla::ComputationDataHandle var, accum;
-    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var));
-    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &accum));
-
     xla::ComputationDataHandle lr = ctx->Input(2);
     xla::ComputationDataHandle grad = ctx->Input(3);
     xla::ComputationDataHandle momentum = ctx->Input(4);
@@ -122,18 +125,10 @@ class ResourceApplyAdagrad : public XlaOpKernel {
 
     DataType type = ctx->input_type(2);
 
-    DataType var_type, accum_type;
     TensorShape var_shape, accum_shape;
-    OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape));
-    OP_REQUIRES_OK(ctx,
-                   ctx->GetVariableTypeAndShape(1, &accum_type, &accum_shape));
-
-    OP_REQUIRES(
-        ctx, type == var_type && type == accum_type,
-        errors::InvalidArgument(
-            "Types of variable arguments to ResourceApplyAdagrad must match: ",
-            DataTypeString(type), " vs. ", DataTypeString(var_type), " and ",
-            DataTypeString(accum_type)));
+    xla::ComputationDataHandle var, accum;
+    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
+    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum));
 
     OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
                 errors::InvalidArgument(
@@ -151,9 +146,6 @@ class ResourceApplyAdagrad : public XlaOpKernel {
                     "var and grad do not have the same shape",
                     var_shape.DebugString(), " ", grad_shape.DebugString()));
 
-    xla::ComputationDataHandle var, accum;
-    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var));
-    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &accum));
     xla::ComputationDataHandle lr = ctx->Input(2);
     xla::ComputationDataHandle grad = ctx->Input(3);
 
@@ -175,18 +167,11 @@ class ResourceApplyAdam : public XlaOpKernel {
   }
 
   void Compile(XlaOpKernelContext* ctx) override {
-    DataType var_type, m_type, v_type;
     TensorShape var_shape, m_shape, v_shape;
-    OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape));
-    OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(1, &m_type, &m_shape));
-    OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(2, &v_type, &v_shape));
-
-    OP_REQUIRES(
-        ctx, dtype_ == var_type && dtype_ == m_type && dtype_ == v_type,
-        errors::InvalidArgument(
-            "Types of variable arguments to ResourceApplyRMSProp must match: ",
-            DataTypeString(dtype_), " vs. ", DataTypeString(var_type), " vs. ",
-            DataTypeString(m_type), " vs. ", DataTypeString(v_type)));
+    xla::ComputationDataHandle var, m, v;
+    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
+    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype_, &m_shape, &m));
+    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &v_shape, &v));
 
     TensorShape beta1_power_shape = ctx->InputShape(3);
     TensorShape beta2_power_shape = ctx->InputShape(4);
@@ -228,10 +213,6 @@ class ResourceApplyAdam : public XlaOpKernel {
                     "var and grad do not have the same shape",
                     var_shape.DebugString(), " ", grad_shape.DebugString()));
 
-    xla::ComputationDataHandle var, m, v;
-    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var));
-    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &m));
-    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, &v));
     xla::ComputationDataHandle beta1_power = ctx->Input(3);
     xla::ComputationDataHandle beta2_power = ctx->Input(4);
     xla::ComputationDataHandle lr = ctx->Input(5);
@@ -278,18 +259,11 @@ class ResourceApplyRMSProp : public XlaOpKernel {
 
     DataType type = ctx->input_type(3);
 
-    DataType var_type, ms_type, mom_type;
     TensorShape var_shape, ms_shape, mom_shape;
-    OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape));
-    OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(1, &ms_type, &ms_shape));
-    OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(2, &mom_type, &mom_shape));
-
-    OP_REQUIRES(
-        ctx, type == var_type && type == ms_type && type == mom_type,
-        errors::InvalidArgument(
-            "Types of variable arguments to ResourceApplyRMSProp must match: ",
-            DataTypeString(type), " vs. ", DataTypeString(var_type), " vs. ",
-            DataTypeString(ms_type), " vs. ", DataTypeString(mom_type)));
+    xla::ComputationDataHandle var, ms, mom;
+    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
+    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &ms_shape, &ms));
+    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, type, &mom_shape, &mom));
 
     TensorShape lr_shape = ctx->InputShape(3);
     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
@@ -323,10 +297,6 @@ class ResourceApplyRMSProp : public XlaOpKernel {
                     "var and grad do not have the same shape",
                     var_shape.DebugString(), " ", grad_shape.DebugString()));
 
-    xla::ComputationDataHandle var, ms, mom;
-    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var));
-    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &ms));
-    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, &mom));
     xla::ComputationDataHandle lr = ctx->Input(3);
     xla::ComputationDataHandle rho = ctx->Input(4);
     xla::ComputationDataHandle momentum = ctx->Input(5);
@@ -373,20 +343,11 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype,
                  bool has_l2_shrinkage) {
   xla::ComputationBuilder* b = ctx->builder();
 
-  DataType var_type, accum_type, linear_type;
   TensorShape var_shape, accum_shape, linear_shape;
-  OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape));
-  OP_REQUIRES_OK(ctx,
-                 ctx->GetVariableTypeAndShape(1, &accum_type, &accum_shape));
-  OP_REQUIRES_OK(ctx,
-                 ctx->GetVariableTypeAndShape(2, &linear_type, &linear_shape));
-
-  OP_REQUIRES(
-      ctx, dtype == var_type && dtype == accum_type && dtype == linear_type,
-      errors::InvalidArgument(
-          "Types of variable arguments to ResourceApplyFtrlV2 must match: ",
-          DataTypeString(dtype), " vs. ", DataTypeString(var_type), " and ",
-          DataTypeString(accum_type), " and ", DataTypeString(linear_type)));
+  xla::ComputationDataHandle var, accum, linear;
+  OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype, &var_shape, &var));
+  OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype, &accum_shape, &accum));
+  OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype, &linear_shape, &linear));
 
   OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
               errors::InvalidArgument(
@@ -438,10 +399,6 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype,
               errors::InvalidArgument("lr_power is not a scalar: ",
                                       lr_power_shape.DebugString()));
 
-  xla::ComputationDataHandle var, accum, linear;
-  OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var));
-  OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &accum));
-  OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, &linear));
   xla::ComputationDataHandle grad = ctx->Input(3);
   xla::ComputationDataHandle lr = ctx->Input(4);
   xla::ComputationDataHandle l1 = ctx->Input(5);
index 68847ae..e4079eb 100644 (file)
@@ -33,21 +33,29 @@ class VarIsInitializedOp : public XlaOpKernel {
  public:
   explicit VarIsInitializedOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
   void Compile(XlaOpKernelContext* ctx) override {
-    xla::ComputationDataHandle handle;
-    bool initialized = ctx->ReadVariableInput(0, &handle).ok();
-    ctx->SetOutput(0, ctx->builder()->ConstantR0<bool>(initialized));
+    XlaResource* variable;
+    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &variable));
+    ctx->SetOutput(0,
+                   ctx->builder()->ConstantR0<bool>(variable->initialized()));
   }
 };
 REGISTER_XLA_OP(Name("VarIsInitializedOp"), VarIsInitializedOp);
 
 class ReadVariableOp : public XlaOpKernel {
  public:
-  explicit ReadVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+  explicit ReadVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
+  }
+
   void Compile(XlaOpKernelContext* ctx) override {
     xla::ComputationDataHandle handle;
-    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle));
+    OP_REQUIRES_OK(
+        ctx, ctx->ReadVariableInput(0, dtype_, /*shape=*/nullptr, &handle));
     ctx->SetOutput(0, handle);
   }
+
+ private:
+  DataType dtype_;
 };
 REGISTER_XLA_OP(Name("ReadVariableOp"), ReadVariableOp);
 
@@ -65,10 +73,12 @@ class AssignAddVariableOp : public XlaOpKernel {
  public:
   explicit AssignAddVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
   void Compile(XlaOpKernelContext* ctx) override {
+    DataType type = ctx->input_type(1);
     xla::ComputationDataHandle handle;
-    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle));
+    OP_REQUIRES_OK(ctx,
+                   ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle));
     handle = ctx->builder()->Add(handle, ctx->Input(1));
-    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, ctx->input_type(1), handle));
+    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle));
   }
 };
 REGISTER_XLA_OP(
@@ -79,10 +89,12 @@ class AssignSubVariableOp : public XlaOpKernel {
  public:
   explicit AssignSubVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
   void Compile(XlaOpKernelContext* ctx) override {
+    DataType type = ctx->input_type(1);
     xla::ComputationDataHandle handle;
-    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle));
+    OP_REQUIRES_OK(ctx,
+                   ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle));
     handle = ctx->builder()->Sub(handle, ctx->Input(1));
-    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, ctx->input_type(1), handle));
+    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle));
   }
 };
 REGISTER_XLA_OP(
@@ -95,28 +107,19 @@ class ResourceGatherOp : public XlaOpKernel {
   void Compile(XlaOpKernelContext* ctx) override {
     xla::ComputationBuilder* builder = ctx->builder();
 
-    // Get the shape of the resource tensor.
-    TensorShape resource_shape;
-    DataType resource_dtype;
-    OP_REQUIRES_OK(
-        ctx, ctx->GetVariableTypeAndShape(0, &resource_dtype, &resource_shape));
-
-    DataType expected_output_dtype = ctx->expected_output_dtype(0);
-    OP_REQUIRES(ctx, resource_dtype == expected_output_dtype,
-                errors::InvalidArgument(
-                    "Variable dtype is ", DataTypeString(resource_dtype),
-                    " but expected output dtype is ",
-                    DataTypeString(expected_output_dtype), "."));
+    DataType type = ctx->expected_output_dtype(0);
 
+    TensorShape resource_shape;
     xla::ComputationDataHandle resource_handle;
-    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &resource_handle));
+    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &resource_shape,
+                                               &resource_handle));
 
     auto indices = ctx->Input(1);
     auto indices_shape = ctx->InputShape(1);
     DataType index_type = ctx->input_type(1);
     xla::ComputationDataHandle gather = XlaComputeGatherDynamicSlice(
-        ctx, resource_handle, resource_shape, indices, indices_shape, 0,
-        resource_dtype, index_type, builder);
+        ctx, resource_handle, resource_shape, indices, indices_shape, 0, type,
+        index_type, builder);
     ctx->SetOutput(0, gather);
   }
 };
index 4a711e4..0ff1b65 100644 (file)
@@ -58,9 +58,8 @@ Status MakeXlaCompilerArgumentsFromInputs(
       }
 
       arg.type = resource->type();
-      if (arg.initialized) {
-        TF_RETURN_IF_ERROR(resource->PackedShape(ctx->builder(), &arg.shape));
-      } else {
+      arg.shape = resource->shape();
+      if (!arg.initialized) {
         *has_uninitialized_vars = true;
       }
       arg.tensor_array_size = resource->tensor_array_size();
@@ -70,14 +69,13 @@ Status MakeXlaCompilerArgumentsFromInputs(
       arg.name = resource->name();
       VLOG(2) << "    resource " << resource->name()
               << " type: " << DataTypeString(arg.type)
-              << " shape: " << xla::ShapeUtil::HumanString(arg.shape)
+              << " shape: " << arg.shape.DebugString()
               << " initialized: " << arg.initialized;
 
     } else {
       arg.kind = XlaCompiler::Argument::kParameter;
       arg.type = ctx->input_type(i);
-      TF_RETURN_IF_ERROR(
-          TensorShapeToXLAShape(arg.type, ctx->InputShape(i), &arg.shape));
+      arg.shape = ctx->InputShape(i);
     }
   }
   return Status::OK();
@@ -154,17 +152,14 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
       XlaCompiler::Argument& arg = arguments[update.input_index];
       if (!arg.initialized) {
         VLOG(2) << "Update shape for argument " << update.input_index << " "
-                << xla::ShapeUtil::HumanString(update.shape);
+                << update.shape.DebugString();
         arg.initialized = true;
 
-        xla::Shape shape = update.shape;
-        if (!update.tensor_array_gradients_accessed.empty()) {
-          shape = xla::ShapeUtil::GetTupleElementShape(shape, 0);
-        }
-        std::unique_ptr<xla::Literal> zero =
-            xla::Literal::CreateFromShape(shape);
-        OP_REQUIRES_OK(ctx, resource->SetValue(
-                                update.type, builder->ConstantLiteral(*zero)));
+        arg.shape = update.shape;
+        OP_REQUIRES_OK(ctx,
+                       resource->SetTypeAndShape(update.type, update.shape));
+
+        OP_REQUIRES_OK(ctx, resource->SetZeroValue(builder));
       }
 
       // Add any TensorArray gradients touched by the body to the enclosing
@@ -182,9 +177,6 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
       for (const auto& gradient : resource->tensor_array_gradients()) {
         arg.tensor_array_gradients.insert(gradient.first);
       }
-
-      // Recompute the argument shape.
-      OP_REQUIRES_OK(ctx, resource->PackedShape(ctx->builder(), &arg.shape));
     }
     // Recompile the body with the "correct" resource shapes.
     VLOG(1) << "Recompiling body with corrected resource shapes";
@@ -292,13 +284,12 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
       OP_REQUIRES_OK(ctx,
                      resource->SetFromPack(
                          arguments[update.input_index].tensor_array_gradients,
-                         builder->GetTupleElement(while_result, pos),
-                         /*reset_initial_values=*/false, builder));
+                         builder->GetTupleElement(while_result, pos), builder));
     }
     VLOG(2) << "Loop-carried variable: pos: " << update.input_index
             << " name: " << resource->name() << " modified: " << update.modified
             << " type: " << DataTypeString(update.type)
-            << " shape: " << xla::ShapeUtil::HumanString(update.shape);
+            << " shape: " << update.shape.DebugString();
     // Copies the identity of the resource variable from input to output
     // unchanged, even if the variable was not modified.
     ctx->op_kernel_context()->set_output(
index 906f229..6051d7d 100644 (file)
@@ -241,9 +241,7 @@ Status CreateXlaArgs(const Graph& graph,
     XlaCompiler::Argument arg;
     arg.kind = XlaCompiler::Argument::kParameter;
     TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &arg.type));
-    TensorShape shape;
-    TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &shape));
-    TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, &arg.shape));
+    TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &arg.shape));
     TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kDebugNameAttr, &arg.name));
     xla_args->push_back(arg);
   }
index 69b2654..c5b4ec5 100644 (file)
@@ -66,13 +66,14 @@ Status CheckSignature(const DataTypeVector& types,
 
 bool XlaCompiler::Argument::operator==(
     const XlaCompiler::Argument& other) const {
-  if (std::tie(kind, resource_kind, type, name, tensor_array_size,
+  if (std::tie(kind, resource_kind, type, name, initialized, tensor_array_size,
                tensor_array_gradients) !=
       std::tie(other.kind, other.resource_kind, other.type, other.name,
-               other.tensor_array_size, other.tensor_array_gradients)) {
+               other.initialized, other.tensor_array_size,
+               other.tensor_array_gradients)) {
     return false;
   }
-  if (!xla::ShapeUtil::Equal(shape, other.shape)) {
+  if (shape != other.shape) {
     return false;
   }
   if (constant_value.shape() != other.constant_value.shape()) {
@@ -230,6 +231,64 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options,
   return Status::OK();
 }
 
+// Computes the XLA shape for argument 'arg'.
+/*static*/ Status XlaCompiler::XLAShapeForArgument(
+    const XlaCompiler::Argument& arg, xla::Shape* xla_shape) {
+  switch (arg.kind) {
+    case XlaCompiler::Argument::kConstant:
+      return TensorShapeToXLAShape(arg.type, arg.constant_value.shape(),
+                                   xla_shape);
+    case XlaCompiler::Argument::kParameter:
+      return TensorShapeToXLAShape(arg.type, arg.shape, xla_shape);
+    case XlaCompiler::Argument::kResource: {
+      TF_RET_CHECK(arg.initialized);
+
+      switch (arg.resource_kind) {
+        case XlaResource::kVariable:
+          return TensorShapeToXLAShape(arg.type, arg.shape, xla_shape);
+        case XlaResource::kTensorArray: {
+          if (arg.tensor_array_size < 0) {
+            return errors::InvalidArgument(
+                "Negative tensor_array_size in XLAShapeForArgument");
+          }
+          TensorShape shape;
+          shape.AddDim(arg.tensor_array_size);
+          shape.AppendShape(arg.shape);
+          TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, xla_shape));
+
+          if (!arg.tensor_array_gradients.empty()) {
+            std::vector<xla::Shape> tuple_shape(
+                arg.tensor_array_gradients.size() + 1, *xla_shape);
+            *xla_shape = xla::ShapeUtil::MakeTupleShape(tuple_shape);
+          }
+          return Status::OK();
+        }
+        case XlaResource::kStack: {
+          if (arg.tensor_array_size < 0) {
+            return errors::InvalidArgument(
+                "Negative tensor_array_size in XLAShapeForArgument");
+          }
+          TensorShape shape;
+          shape.AddDim(arg.tensor_array_size);
+          shape.AppendShape(arg.shape);
+          xla::Shape buffer_shape;
+          TF_RETURN_IF_ERROR(
+              TensorShapeToXLAShape(arg.type, shape, &buffer_shape));
+          *xla_shape = xla::ShapeUtil::MakeTupleShape(
+              {buffer_shape, xla::ShapeUtil::MakeShape(xla::S32, {})});
+          return Status::OK();
+        }
+
+        case XlaResource::kInvalid:
+          return errors::Internal(
+              "Invalid resource type in XLAShapeForArgument()");
+      }
+    }
+    case XlaCompiler::Argument::kInvalid:
+      return errors::Internal("Invalid argument type in XLAShapeForArgument()");
+  }
+}
+
 namespace {
 
 Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
@@ -275,8 +334,9 @@ Status BuildArguments(const Graph& graph,
 
   // Argument numbers of arguments and resources that are to be passed to the
   // XLA computation as runtime parameters.
-  std::vector<int> parameters, resources;
-  parameters.reserve(args.size());
+  input_mapping->clear();
+  input_mapping->reserve(args.size());
+  std::vector<int> resources;
   resources.reserve(args.size());
 
   // Fills in constant arguments, and computes non-constant argument order.
@@ -290,18 +350,20 @@ Status BuildArguments(const Graph& graph,
         // TODO(phawkins): this code assumes that resource arguments do not
         // alias.
         XlaResource* resource;
-        TF_RETURN_IF_ERROR(
-            context->CreateResource(arg.resource_kind, i, arg.name, arg.type,
-                                    xla::ComputationDataHandle(), &resource));
-        resource->set_tensor_array_size(arg.tensor_array_size);
+        TF_RETURN_IF_ERROR(context->CreateResource(
+            arg.resource_kind, i, arg.name, arg.type, arg.shape,
+            xla::ComputationDataHandle(),
+            /*tensor_array_size=*/arg.tensor_array_size,
+            /*tensor_array_gradients=*/arg.tensor_array_gradients, &resource));
         arg_expression.set_resource(resource);
         if (arg.initialized) {
           resources.push_back(i);
         }
         break;
-      case XlaCompiler::Argument::kParameter:
-        parameters.push_back(i);
+      case XlaCompiler::Argument::kParameter: {
+        input_mapping->push_back(i);
         break;
+      }
       case XlaCompiler::Argument::kConstant:
         arg_expression.set_constant_value(arg.constant_value);
         break;
@@ -312,19 +374,17 @@ Status BuildArguments(const Graph& graph,
 
   // Append parameters containing variable values after the other runtime
   // parameters.
-  parameters.insert(parameters.end(), resources.begin(), resources.end());
-  if (parameters.empty()) {
+  input_mapping->insert(input_mapping->end(), resources.begin(),
+                        resources.end());
+  if (input_mapping->empty()) {
     return Status::OK();
   }
 
-  std::vector<xla::Shape> arg_shapes;
-  arg_shapes.reserve(parameters.size());
-  input_mapping->resize(parameters.size());
-  for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) {
-    const XlaCompiler::Argument& arg = args[parameters[i]];
+  std::vector<xla::Shape> arg_shapes(input_mapping->size());
+  for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
     // Computes the shapes of non-constant arguments.
-    arg_shapes.push_back(arg.shape);
-    (*input_mapping)[i] = parameters[i];
+    TF_RETURN_IF_ERROR(XlaCompiler::XLAShapeForArgument(
+        args[(*input_mapping)[i]], &arg_shapes[i]));
   }
 
   if (use_tuple_arg) {
@@ -354,13 +414,13 @@ Status BuildArguments(const Graph& graph,
   }
 
   // Build parameter handles for non-constant arguments.
-  std::vector<xla::ComputationDataHandle> arg_handles(parameters.size());
+  std::vector<xla::ComputationDataHandle> arg_handles(input_mapping->size());
   if (use_tuple_arg) {
     xla::ComputationDataHandle tuple;
     if (is_entry_computation) {
       xla::OpSharding tuple_sharding;
       tuple_sharding.set_type(xla::OpSharding::Type::OpSharding_Type_TUPLE);
-      for (int64 parameter : parameters) {
+      for (int64 parameter : *input_mapping) {
         const int core = (*arg_cores)[parameter];
         const int root_device = 0;
         *tuple_sharding.add_tuple_shardings() =
@@ -373,16 +433,16 @@ Status BuildArguments(const Graph& graph,
     } else {
       tuple = builder->Parameter(0, (*input_shapes)[0], "arg_tuple");
     }
-    for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) {
-      const int core = (*arg_cores)[parameters[i]];
+    for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
+      const int core = (*arg_cores)[input_mapping->at(i)];
       xla::ScopedShardingAssignment assign_sharding(
           builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>()
                               : xla::sharding_builder::AssignDevice(core));
       arg_handles[i] = builder->GetTupleElement(tuple, i);
     }
   } else {
-    for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) {
-      const int core = (*arg_cores)[parameters[i]];
+    for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
+      const int core = (*arg_cores)[input_mapping->at(i)];
       xla::ScopedShardingAssignment assign_sharding(
           builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>()
                               : xla::sharding_builder::AssignDevice(core));
@@ -393,19 +453,18 @@ Status BuildArguments(const Graph& graph,
 
   // Fill in the handles in non-constant arguments.
   VLOG(2) << "XLA computation inputs:";
-  for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) {
-    const XlaCompiler::Argument& arg = args[parameters[i]];
+  for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
+    const XlaCompiler::Argument& arg = args[input_mapping->at(i)];
     VLOG(2) << "  XLA arg " << i
             << " shape: " << xla::ShapeUtil::HumanString(arg_shapes[i])
-            << " name: " << arg.name << " TF arg " << parameters[i];
-    XlaExpression& arg_expression = (*arg_expressions)[parameters[i]];
+            << " name: " << arg.name << " TF arg " << input_mapping->at(i);
+    XlaExpression& arg_expression = (*arg_expressions)[input_mapping->at(i)];
     switch (arg.kind) {
       case XlaCompiler::Argument::kResource: {
         TF_RET_CHECK(arg.initialized);
         XlaResource* resource = arg_expression.resource();
-        TF_RETURN_IF_ERROR(
-            resource->SetFromPack(arg.tensor_array_gradients, arg_handles[i],
-                                  /*reset_initial_values=*/true, builder));
+        TF_RETURN_IF_ERROR(resource->SetFromPack(arg.tensor_array_gradients,
+                                                 arg_handles[i], builder));
         VLOG(2) << "    resource: num_gradients: "
                 << arg.tensor_array_gradients.size();
         break;
@@ -486,6 +545,7 @@ Status BuildComputation(
       XlaCompiler::ResourceUpdate& update = resource_updates->back();
       update.input_index = resource->arg_num();
       update.type = resource->type();
+      update.shape = resource->shape();
       update.modified = modified;
       for (const auto& grad : resource->tensor_array_gradients()) {
         update.tensor_array_gradients_accessed.insert(grad.first);
@@ -616,13 +676,6 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
       ++computation_output;
     }
   }
-
-  for (std::vector<ResourceUpdate>::size_type i = 0;
-       i < result->resource_updates.size(); ++i) {
-    result->resource_updates[i].shape = xla::ShapeUtil::GetTupleElementShape(
-        result->xla_output_shape, computation_output);
-    ++computation_output;
-  }
   return Status::OK();
 }
 
index 30d3c05..b86c82c 100644 (file)
@@ -104,9 +104,17 @@ class XlaCompiler {
     // is the type of the variable's value, not DT_RESOURCE.
     DataType type;
 
-    // The shape of the argument. If the argument is a resource, this is the
-    // shape of the resource's value.
-    xla::Shape shape;
+    // The shape of the argument. For:
+    // * a parameter: the shape of the parameter.
+    // * a constant: ignored; the shape given by constant_value is used
+    //     instead.
+    // * an uninitialized resource: ignored. We don't yet know the shape of an
+    //     uninitialized resource (otherwise we would have initialized it!)
+    // * an initialized variable: the shape of the variable's value.
+    // * an initialized TensorArray or Stack resource: the shape of an entry in
+    //   the TensorArray/Stack. Note this is the size of a single entry, not the
+    //   XLA data structure that represents the complete stack/array.
+    TensorShape shape;
 
     // The value of the argument, if it is a compile-time constant. Must be a
     // host-memory tensor.
@@ -175,8 +183,9 @@ class XlaCompiler {
     int input_index;
 
     // Type and shape of the tensor to be written back.
+    // The `shape` field has the same meaning as the Argument::shape field.
     DataType type;
-    xla::Shape shape;
+    TensorShape shape;
 
     // Was the value of the variable modified by the computation?
     // (Always true, unless `return_updated_values_for_all_resources` is true.)
@@ -266,11 +275,10 @@ class XlaCompiler {
                       const std::vector<Argument>& args,
                       CompilationResult* result);
 
-  Status PrepareArguments(xla::ComputationBuilder* builder, NameAttrList func,
-                          const std::vector<DataType>& types,
-                          const std::vector<TensorShape>& shapes,
-                          const std::vector<const XlaExpression*>& expressions,
-                          std::vector<Argument>* args);
+  // Returns the shape of the XLA parameter for an argument 'arg'.
+  // See the class comment for more details about the argument passing
+  // convention.
+  static Status XLAShapeForArgument(const Argument& arg, xla::Shape* xla_shape);
 
   // Retrieves the channel handle associated with `key`. Allocates
   // a new channel handle if none exists.
index 7ebe4b7..65de4db 100644 (file)
@@ -191,10 +191,10 @@ TEST_F(XlaCompilerTest, Simple) {
   std::vector<XlaCompiler::Argument> args(2);
   args[0].kind = XlaCompiler::Argument::kParameter;
   args[0].type = DT_INT32;
-  args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
+  args[0].shape = TensorShape({2});
   args[1].kind = XlaCompiler::Argument::kParameter;
   args[1].type = DT_INT32;
-  args[1].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
+  args[1].shape = TensorShape({2});
 
   // Compiles the graph.
   XlaCompiler compiler(DefaultOptions());
@@ -242,10 +242,10 @@ TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) {
   std::vector<XlaCompiler::Argument> args(2);
   args[0].kind = XlaCompiler::Argument::kParameter;
   args[0].type = DT_INT32;
-  args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
+  args[0].shape = TensorShape({2});
   args[1].kind = XlaCompiler::Argument::kParameter;
   args[1].type = DT_INT32;
-  args[1].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
+  args[1].shape = TensorShape({2});
 
   // Compiles the graph.
   XlaCompiler compiler(DefaultOptions());
@@ -281,7 +281,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
   std::vector<XlaCompiler::Argument> args(1);
   args[0].kind = XlaCompiler::Argument::kParameter;
   args[0].type = DT_INT32;
-  args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
+  args[0].shape = TensorShape({2});
 
   XlaCompiler::Options options = DefaultOptions();
   XlaCompiler compiler(options);
@@ -373,7 +373,7 @@ TEST_F(XlaCompilerTest, ResourceManager) {
   std::vector<XlaCompiler::Argument> args(1);
   args[0].kind = XlaCompiler::Argument::kParameter;
   args[0].type = DT_INT32;
-  args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
+  args[0].shape = TensorShape({2});
 
   DummyResourceForTest* resource = new DummyResourceForTest();
 
@@ -420,7 +420,7 @@ TEST_F(XlaCompilerTest, DeterministicCompilation) {
     std::vector<XlaCompiler::Argument> args(1);
     args[0].kind = XlaCompiler::Argument::kParameter;
     args[0].type = DT_INT32;
-    args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
+    args[0].shape = TensorShape({2});
 
     // Compiles the graph.
     auto options = DefaultOptions();
@@ -472,9 +472,7 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) {
   args[0].resource_kind = XlaResource::kTensorArray;
   args[0].initialized = true;
   args[0].type = DT_INT32;
-  args[0].shape = xla::ShapeUtil::MakeTupleShape(
-      {xla::ShapeUtil::MakeShape(xla::S32, {2}),
-       xla::ShapeUtil::MakeShape(xla::S32, {2})});
+  args[0].shape = TensorShape({});
   args[0].tensor_array_size = 2;
   args[0].tensor_array_gradients = {"grad2"};
 
@@ -540,9 +538,7 @@ TEST_F(XlaCompilerTest, UnwrittenTensorArrayGradientsAreNotComputationOutputs) {
   args[0].resource_kind = XlaResource::kTensorArray;
   args[0].initialized = true;
   args[0].type = DT_INT32;
-  args[0].shape = xla::ShapeUtil::MakeTupleShape(
-      {xla::ShapeUtil::MakeShape(xla::S32, {2}),
-       xla::ShapeUtil::MakeShape(xla::S32, {2})});
+  args[0].shape = TensorShape({});
   args[0].tensor_array_size = 2;
   args[0].tensor_array_gradients = {"grad1"};
 
@@ -574,9 +570,7 @@ TEST_F(XlaCompilerTest, NewTensorArrayGradientsAreComputationOutputs) {
   args[0].resource_kind = XlaResource::kTensorArray;
   args[0].initialized = true;
   args[0].type = DT_INT32;
-  args[0].shape = xla::ShapeUtil::MakeTupleShape(
-      {xla::ShapeUtil::MakeShape(xla::S32, {2}),
-       xla::ShapeUtil::MakeShape(xla::S32, {2})});
+  args[0].shape = TensorShape({});
   args[0].tensor_array_size = 2;
   args[0].tensor_array_gradients = {"grad1"};
 
index e8d17e2..7387895 100644 (file)
@@ -103,12 +103,14 @@ Status XlaContext::AddConstRetval(int retval_index, DataType dtype,
 
 xla::ComputationBuilder* XlaContext::builder() { return builder_; }
 
-Status XlaContext::CreateResource(XlaResource::Kind kind, int arg_num,
-                                  string name, DataType type,
-                                  const xla::ComputationDataHandle& handle,
-                                  XlaResource** resource) {
+Status XlaContext::CreateResource(
+    XlaResource::Kind kind, int arg_num, string name, DataType type,
+    TensorShape shape, const xla::ComputationDataHandle& handle,
+    int64 tensor_array_size, const std::set<string>& tensor_array_gradients,
+    XlaResource** resource) {
   resources_.emplace_back(
-      new XlaResource(kind, arg_num, std::move(name), type, handle));
+      new XlaResource(kind, arg_num, std::move(name), type, std::move(shape),
+                      handle, tensor_array_size, tensor_array_gradients));
   *resource = resources_.back().get();
   return Status::OK();
 }
index 1a7dafe..fac0352 100644 (file)
@@ -71,11 +71,15 @@ class XlaContext : public ResourceBase {
   Status AddConstRetval(int retval_index, DataType dtype,
                         const xla::Literal& literal);
 
-  // Creates a resource with resource `kind` and initial type `type` and
-  // value `handle`. `name` is a descriptive name for use in error messages.
+  // Creates a resource with resource `kind` and initial value `handle`. `name`
+  // is a descriptive name for use in error messages. See the `XlaResource`
+  // constructor for a description of the remaining arguments.
   // Fails if the resource already exists.
   Status CreateResource(XlaResource::Kind kind, int arg_num, string name,
-                        DataType type, const xla::ComputationDataHandle& handle,
+                        DataType type, TensorShape shape,
+                        const xla::ComputationDataHandle& handle,
+                        int64 tensor_array_size,
+                        const std::set<string>& tensor_array_gradients,
                         XlaResource** resource);
 
   const std::vector<std::unique_ptr<XlaResource>>& resources() {
index ee0aed6..ee29158 100644 (file)
@@ -286,7 +286,8 @@ Status XlaOpKernelContext::ConstantInputList(
 }
 
 Status XlaOpKernelContext::ReadVariableInput(
-    int index, xla::ComputationDataHandle* value) {
+    int index, DataType type, TensorShape* shape,
+    xla::ComputationDataHandle* value) {
   const Tensor& tensor = context_->input(index);
   const XlaExpression* expression = CastExpressionFromTensor(tensor);
   XlaResource* variable = expression->resource();
@@ -296,7 +297,15 @@ Status XlaOpKernelContext::ReadVariableInput(
     return errors::InvalidArgument("Read of uninitialized variable ",
                                    variable->name());
   }
+  if (variable->type() != type) {
+    return errors::InvalidArgument(
+        "Type mismatch for read of variable ", variable->name(), ". Expected ",
+        DataTypeString(type), "; got ", DataTypeString(variable->type()));
+  }
   *value = variable->value();
+  if (shape) {
+    *shape = variable->shape();
+  }
   return Status::OK();
 }
 
@@ -312,12 +321,7 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
                                    variable->name());
   }
   *type = variable->type();
-  auto shape_or_status = builder()->GetShape(variable->value());
-  if (!shape_or_status.ok()) {
-    return shape_or_status.status();
-  }
-  TF_RETURN_IF_ERROR(
-      XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), shape));
+  *shape = variable->shape();
   return Status::OK();
 }
 
@@ -405,7 +409,17 @@ Status XlaOpKernelContext::AssignVariable(
   XlaResource* variable = expression->resource();
   TF_RET_CHECK(variable != nullptr);
   TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
-  return variable->SetValue(type, handle);
+
+  auto shape_or_status = builder()->GetShape(handle);
+  if (!shape_or_status.ok()) {
+    return shape_or_status.status();
+  }
+  TensorShape shape;
+  TF_RETURN_IF_ERROR(
+      XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), &shape));
+
+  TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape));
+  return variable->SetValue(handle);
 }
 
 XlaCompiler* XlaOpKernelContext::compiler() const {
index 6d3b6db..e1fd0f5 100644 (file)
@@ -164,11 +164,16 @@ class XlaOpKernelContext {
                                  TensorShape* shape) const;
 
   // Reads the current value of the resouce variable referred to by input
-  // 'index'.
-  Status ReadVariableInput(int index, xla::ComputationDataHandle* value);
+  // 'index'. If `shape` is not nullptr, sets `*shape` to the shape of the
+  // variable. Returns an error if the variable has not been initialized, or if
+  // its type does not match `type`.
+  Status ReadVariableInput(int index, DataType type, TensorShape* shape,
+                           xla::ComputationDataHandle* value);
 
   // Assigns the value `handle` to the variable referenced by input
-  // `input_index`. Marks the operator as having side effects.
+  // `input_index`. The variable must be of `type`. Returns an error if the
+  // variable has been initialized with a different type or with a
+  // different shape.
   Status AssignVariable(int input_index, DataType type,
                         const xla::ComputationDataHandle& handle);
 
index 9abac8b..c2075b4 100644 (file)
@@ -25,51 +25,99 @@ limitations under the License.
 
 namespace tensorflow {
 
-XlaResource::XlaResource(Kind kind, int arg_num, string name,
-                         DataType initial_type,
-                         const xla::ComputationDataHandle& initial_value)
+XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type,
+                         TensorShape shape,
+                         const xla::ComputationDataHandle& initial_value,
+                         int64 tensor_array_size,
+                         const std::set<string>& tensor_array_gradients)
     : kind_(kind),
       arg_num_(arg_num),
       name_(std::move(name)),
-      type_(initial_type),
+      type_(type),
+      shape_(std::move(shape)),
       value_(initial_value),
-      initial_value_(initial_value) {
+      initial_value_(initial_value),
+      tensor_array_size_(tensor_array_size) {
   CHECK(kind_ != kInvalid);
+
+  for (const string& gradient : tensor_array_gradients) {
+    tensor_array_gradients_[gradient].reset(
+        new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1,
+                        /*name=*/strings::StrCat("TensorArrayGrad: ", name_),
+                        type_, shape_, xla::ComputationDataHandle(),
+                        tensor_array_size_, /*tensor_array_gradients=*/{}));
+  }
 }
 
-Status XlaResource::SetValue(DataType type,
-                             const xla::ComputationDataHandle& value) {
-  if (type_ == DT_INVALID && type == DT_INVALID) {
-    return errors::InvalidArgument("Attempted to initialized resource ", name_,
-                                   " to an invalid type");
+Status XlaResource::SetTypeAndShape(DataType type, const TensorShape& shape) {
+  if (type == DT_INVALID) {
+    return errors::InvalidArgument("Attempted to set type of resource '", name_,
+                                   "'' to an invalid type");
   }
-  if (type_ != DT_INVALID && type_ != type) {
+  if (initialized() && type_ != type) {
     return errors::InvalidArgument("Type of resource ", name_,
                                    " cannot be changed after initialization: "
                                    "old type was ",
                                    DataTypeString(type_), ", new type is ",
                                    DataTypeString(type));
   }
+  if (initialized() && shape_ != shape) {
+    return errors::InvalidArgument("Shape of resource ", name_,
+                                   " cannot be changed after initialization: "
+                                   "old shape was ",
+                                   shape_.DebugString(), ", new shape is ",
+                                   shape.DebugString());
+  }
   type_ = type;
-  value_ = value;
+  shape_ = shape;
   return Status::OK();
 }
 
-Status XlaResource::GetXlaShape(xla::ComputationBuilder* builder,
-                                xla::Shape* shape) const {
-  auto shape_or_status = builder->GetShape(value_);
-  if (!shape_or_status.ok()) {
-    return shape_or_status.status();
+Status XlaResource::SetValue(const xla::ComputationDataHandle& value) {
+  if (type_ == DT_INVALID) {
+    return errors::InvalidArgument(
+        "Resource '", name_,
+        "' must be initialized with a valid type before use.");
   }
-  *shape = *shape_or_status.ValueOrDie();
+  value_ = value;
   return Status::OK();
 }
 
-Status XlaResource::GetShape(xla::ComputationBuilder* builder,
-                             TensorShape* shape) const {
-  xla::Shape xla_shape;
-  TF_RETURN_IF_ERROR(GetXlaShape(builder, &xla_shape));
-  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, shape));
+Status XlaResource::SetZeroValue(xla::ComputationBuilder* builder) {
+  if (type_ == DT_INVALID) {
+    return errors::InvalidArgument(
+        "Resource '", name_,
+        "' must be initialized with a valid type before use.");
+  }
+  switch (kind_) {
+    case kVariable: {
+      value_ = builder->Broadcast(XlaHelpers::Zero(builder, type_),
+                                  shape_.dim_sizes());
+      break;
+    }
+    case kTensorArray: {
+      TensorShape ta_shape;
+      ta_shape.AddDim(tensor_array_size_);
+      ta_shape.AppendShape(shape_);
+      value_ = builder->Broadcast(XlaHelpers::Zero(builder, type_),
+                                  ta_shape.dim_sizes());
+      break;
+    }
+    case kStack: {
+      TensorShape ta_shape;
+      ta_shape.AddDim(tensor_array_size_);
+      ta_shape.AppendShape(shape_);
+      value_ =
+          builder->Tuple({builder->Broadcast(XlaHelpers::Zero(builder, type_),
+                                             ta_shape.dim_sizes()),
+                          builder->ConstantR0<int32>(0)});
+      break;
+    }
+
+    case kInvalid:
+    default:
+      LOG(FATAL) << "Invalid resource type";
+  }
   return Status::OK();
 }
 
@@ -82,36 +130,20 @@ Status XlaResource::GetOrCreateTensorArrayGradient(
   std::unique_ptr<XlaResource>& gradient = tensor_array_gradients_[source];
   if (!gradient) {
     TensorShape ta_shape;
-    TF_RETURN_IF_ERROR(GetShape(builder, &ta_shape));
+    ta_shape.AddDim(tensor_array_size_);
+    ta_shape.AppendShape(shape_);
     xla::ComputationDataHandle gradient_value = builder->Broadcast(
         XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes());
     gradient.reset(
         new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1,
                         /*name=*/strings::StrCat("TensorArrayGrad: ", name_),
-                        type_, gradient_value));
-    gradient->tensor_array_size_ = tensor_array_size_;
+                        type_, shape_, gradient_value, tensor_array_size_,
+                        /*tensor_array_gradients=*/{}));
   }
   *gradient_out = gradient.get();
   return Status::OK();
 }
 
-Status XlaResource::PackedShape(xla::ComputationBuilder* builder,
-                                xla::Shape* packed_shape) const {
-  if (tensor_array_gradients_.empty()) {
-    return GetXlaShape(builder, packed_shape);
-  }
-  TF_RET_CHECK(kind_ == kTensorArray);
-  std::vector<xla::Shape> elem_shapes(1 + tensor_array_gradients_.size());
-  int pos = 0;
-  TF_RETURN_IF_ERROR(GetXlaShape(builder, &elem_shapes[pos++]));
-  for (const auto& gradient : tensor_array_gradients_) {
-    TF_RETURN_IF_ERROR(
-        gradient.second->GetXlaShape(builder, &elem_shapes[pos++]));
-  }
-  *packed_shape = xla::ShapeUtil::MakeTupleShape(elem_shapes);
-  return Status::OK();
-}
-
 Status XlaResource::Pack(xla::ComputationDataHandle* pack,
                          xla::ComputationBuilder* builder) const {
   if (tensor_array_gradients_.empty()) {
@@ -130,27 +162,32 @@ Status XlaResource::Pack(xla::ComputationDataHandle* pack,
 
 Status XlaResource::SetFromPack(const std::set<string>& gradient_sources,
                                 const xla::ComputationDataHandle& pack,
-                                bool reset_initial_values,
                                 xla::ComputationBuilder* builder) {
   if (gradient_sources.empty()) {
+    if (!initialized()) {
+      initial_value_ = pack;
+    }
     value_ = pack;
   } else {
     TF_RET_CHECK(kind_ == kTensorArray);
     int pos = 0;
-    value_ = builder->GetTupleElement(pack, pos++);
+    auto v = builder->GetTupleElement(pack, pos++);
+    if (!initialized()) {
+      initial_value_ = v;
+    }
+    value_ = v;
+
     for (const auto& source : gradient_sources) {
       XlaResource* gradient;
       TF_RETURN_IF_ERROR(
           GetOrCreateTensorArrayGradient(source, builder, &gradient));
-      gradient->value_ = builder->GetTupleElement(pack, pos++);
-      if (reset_initial_values) {
-        gradient->initial_value_ = gradient->value_;
+      auto v = builder->GetTupleElement(pack, pos++);
+      if (!gradient->initialized()) {
+        gradient->initial_value_ = v;
       }
+      gradient->value_ = v;
     }
   }
-  if (reset_initial_values) {
-    initial_value_ = value_;
-  }
   return Status::OK();
 }
 
index 6b46089..1bb2c72 100644 (file)
@@ -36,8 +36,11 @@ class XlaResource {
     kStack,
   };
 
-  XlaResource(Kind kind, int arg_num, string name, DataType initial_type,
-              const xla::ComputationDataHandle& initial_value);
+  XlaResource(Kind kind, int arg_num, string name, DataType type,
+              TensorShape shape,
+              const xla::ComputationDataHandle& initial_value,
+              int64 tensor_array_size,
+              const std::set<string>& tensor_array_gradients);
 
   XlaResource(const XlaResource&) = delete;
   XlaResource(XlaResource&&) = delete;
@@ -60,6 +63,12 @@ class XlaResource {
   // a resource is first initialized we do not yet know its type, so we keep
   // track of its type dynamically.
   DataType type() const { return type_; }
+
+  // Shape of the resource. For an uninitialized resource, this is ignored.
+  // For a Variable, this is the shape of the value. For a TensorArray or Stack
+  // this is the shape of each entry in the TensorArray/Stack.
+  const TensorShape& shape() const { return shape_; }
+
   const xla::ComputationDataHandle& value() const { return value_; }
 
   // Value of the resource at computation entry. Used to detect which
@@ -68,17 +77,19 @@ class XlaResource {
     return initial_value_;
   }
 
+  // A variable is initialized if it has a value.
   bool initialized() const { return value_.handle() > 0; }
 
-  // Sets the current type/value of the resource.
-  Status SetValue(DataType type, const xla::ComputationDataHandle& value);
+  // Sets the type and shape of the resource. The type and shape of a resource
+  // must not change once the variable has been initialized.
+  Status SetTypeAndShape(DataType type, const TensorShape& shape);
 
-  // Returns the shape of the resource as an xla::Shape.
-  Status GetXlaShape(xla::ComputationBuilder* builder, xla::Shape* shape) const;
+  // Sets the current value of the resource. Returns an error if the type is not
+  // set to a valid value.
+  Status SetValue(const xla::ComputationDataHandle& value);
 
-  // Returns the shape of the resource as an TensorShape. Fails if the shape is
-  // not representable as a TensorShape.
-  Status GetShape(xla::ComputationBuilder* builder, TensorShape* shape) const;
+  // Sets the current value of the resource to an all-zero value.
+  Status SetZeroValue(xla::ComputationBuilder* builder);
 
   // Looks up the gradient for `source`, or creates it if it does not already
   // exist. The call target must be an initialized TensorArray resource. A
@@ -96,10 +107,6 @@ class XlaResource {
   Status Pack(xla::ComputationDataHandle* pack,
               xla::ComputationBuilder* builder) const;
 
-  // Returns the shape of the `pack` value computed by `Pack()`.
-  Status PackedShape(xla::ComputationBuilder* builder,
-                     xla::Shape* packed_shape) const;
-
   // Updates the resource with values from `pack`. If `gradient_sources` is
   // non-empty, treats `pack` as a tuple that represents a TensorArray and
   // its gradients, and unpacks and updates the gradient resources.
@@ -108,14 +115,14 @@ class XlaResource {
   // Opposite of Pack().
   Status SetFromPack(const std::set<string>& gradient_sources,
                      const xla::ComputationDataHandle& pack,
-                     bool reset_initial_values,
                      xla::ComputationBuilder* builder);
 
-  // TensorArray-specific fields
+  // TensorArray and Stack specific fields
 
   // 'tensor_array_size' stores the expected size of the TensorArray or Stack.
   // We need to store this since sometimes TensorArrays must be initialized
   // lazily since we do not know the element shape at construction time.
+  // Used by both TensorArrays and Stacks.
   int64 tensor_array_size() const { return tensor_array_size_; }
   void set_tensor_array_size(int64 size) { tensor_array_size_ = size; }
 
@@ -136,6 +143,7 @@ class XlaResource {
   const string name_;
 
   DataType type_;
+  TensorShape shape_;
   xla::ComputationDataHandle value_;
   xla::ComputationDataHandle initial_value_;