From 22116459b258d5753aa76410ab6f4d3cbc928a5a Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 2 Feb 2018 11:31:06 -0800 Subject: [PATCH] [TF:XLA] Improve/refactor the handling of resource types/shapes. Previously we used an xla::Shape to track the shape of a resource (Variable, TensorArray, Stack) shape. The xla::Shape described how the resource was represented to XLA, e.g., as a (buffer, size) pair for a Stack resource. Instead, separate the TensorFlow abstract shape representation from the XLA shape representation and track it separately. This leads to simpler and more readable code. PiperOrigin-RevId: 184310694 --- tensorflow/compiler/jit/kernels/xla_launch_op.cc | 4 +- tensorflow/compiler/jit/xla_compilation_cache.cc | 11 +- tensorflow/compiler/tf2xla/graph_compiler.cc | 4 +- tensorflow/compiler/tf2xla/kernels/stack_ops.cc | 20 ++- .../compiler/tf2xla/kernels/strided_slice_op.cc | 11 +- .../compiler/tf2xla/kernels/tensor_array_ops.cc | 33 +++-- tensorflow/compiler/tf2xla/kernels/training_ops.cc | 113 ++++++----------- tensorflow/compiler/tf2xla/kernels/variable_ops.cc | 51 ++++---- tensorflow/compiler/tf2xla/kernels/while_op.cc | 33 ++--- tensorflow/compiler/tf2xla/tf2xla.cc | 4 +- tensorflow/compiler/tf2xla/xla_compiler.cc | 133 ++++++++++++++------ tensorflow/compiler/tf2xla/xla_compiler.h | 26 ++-- tensorflow/compiler/tf2xla/xla_compiler_test.cc | 26 ++-- tensorflow/compiler/tf2xla/xla_context.cc | 12 +- tensorflow/compiler/tf2xla/xla_context.h | 10 +- tensorflow/compiler/tf2xla/xla_op_kernel.cc | 30 +++-- tensorflow/compiler/tf2xla/xla_op_kernel.h | 11 +- tensorflow/compiler/tf2xla/xla_resource.cc | 139 +++++++++++++-------- tensorflow/compiler/tf2xla/xla_resource.h | 38 +++--- 19 files changed, 385 insertions(+), 324 deletions(-) diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index 17ae2bb..6353149 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -376,8 +376,6 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { OP_REQUIRES(ctx, write.input_index >= 0 && write.input_index < ctx->num_inputs(), errors::Internal("Invalid input index for variable write.")); - TensorShape write_shape; - OP_REQUIRES_OK(ctx, XLAShapeToTensorShape(write.shape, &write_shape)); gpu::DeviceMemoryBase buffer = output->buffer({output_num}); @@ -399,7 +397,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { // Looks up the owning Tensor by buffer address. 
OP_REQUIRES_OK( - ctx, xla_allocator.MakeTensorFromBuffer(buffer, write.type, write_shape, + ctx, xla_allocator.MakeTensorFromBuffer(buffer, write.type, write.shape, variable->tensor())); ++output_num; } diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 21d3a54..6d854a9 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -148,8 +148,7 @@ Status BuildArguments(int num_constant_args, XlaCompiler::Argument& arg = (*args)[input_num]; arg.kind = XlaCompiler::Argument::kConstant; arg.type = input.dtype(); - TF_RETURN_IF_ERROR( - TensorShapeToXLAShape(input.dtype(), input.shape(), &arg.shape)); + arg.shape = input.shape(); arg.constant_value = input; ++input_num; } @@ -170,8 +169,7 @@ Status BuildArguments(int num_constant_args, arg.constant_value = input; } arg.type = input.dtype(); - TF_RETURN_IF_ERROR( - TensorShapeToXLAShape(input.dtype(), input.shape(), &arg.shape)); + arg.shape = input.shape(); ++input_num; } @@ -189,8 +187,7 @@ Status BuildArguments(int num_constant_args, if (variable_args[variable_id].present) { const Tensor& value = variable_args[variable_id].value; arg.type = value.dtype(); - TF_RETURN_IF_ERROR( - TensorShapeToXLAShape(value.dtype(), value.shape(), &arg.shape)); + arg.shape = value.shape(); arg.initialized = true; } else { // The values of uninitialized variables are not passed as inputs, since @@ -199,7 +196,7 @@ Status BuildArguments(int num_constant_args, // uninitialized variables. arg.initialized = false; arg.type = DT_INVALID; - arg.shape = xla::Shape(); + arg.shape = TensorShape(); } ++input_num; } diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index 02215b5..1418d95 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -60,9 +60,7 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, for (int i = 0; i < args->size(); ++i) { XlaCompiler::Argument& arg = (*args)[i]; arg.type = ctx->input_type(i); - - TF_RETURN_IF_ERROR( - TensorShapeToXLAShape(arg.type, ctx->InputShape(i), &arg.shape)); + arg.shape = ctx->InputShape(i); if (arg.type == DT_RESOURCE) { return errors::InvalidArgument( diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc index d77fb76..1a78c7a 100644 --- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -77,10 +77,8 @@ Status MaybeInitializeStack(xla::ComputationBuilder* builder, // Stack has not been initialized. xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, resource->type()); - TF_RETURN_IF_ERROR(resource->SetValue( - dtype, - builder->Tuple({builder->Broadcast(zero, stack_shape.dim_sizes()), - builder->ConstantR0(0)}))); + TF_RETURN_IF_ERROR(resource->SetTypeAndShape(dtype, elem_shape)); + TF_RETURN_IF_ERROR(resource->SetZeroValue(builder)); } else { // Checks the expected shape matches the actual shape. 
TensorShape actual_shape; @@ -119,8 +117,8 @@ class StackOp : public XlaOpKernel { string name = strings::StrCat("Stack: ", stack_name_); OP_REQUIRES_OK( ctx, xc.CreateResource(XlaResource::kStack, -1, std::move(name), dtype_, - value, &resource)); - resource->set_tensor_array_size(size); + TensorShape(), value, /*tensor_array_size=*/size, + /*tensor_array_gradients=*/{}, &resource)); ctx->SetResourceOutput(0, resource); } @@ -164,11 +162,9 @@ class StackPushOp : public XlaOpKernel { // TODO(phawkins): We don't check the index is in bounds --- there is no // error mechanism in XLA. - OP_REQUIRES_OK( - ctx, - resource->SetValue( - dtype_, b->Tuple({b->DynamicUpdateSlice(ta, update, start_indices), - b->Add(index, b->ConstantR0(1))}))); + OP_REQUIRES_OK(ctx, resource->SetValue(b->Tuple( + {b->DynamicUpdateSlice(ta, update, start_indices), + b->Add(index, b->ConstantR0(1))}))); ctx->SetOutput(0, value); } @@ -208,7 +204,7 @@ class StackPopOp : public XlaOpKernel { xla::ComputationDataHandle index = b->GetTupleElement(state, 1); index = b->Sub(index, b->ConstantR0(1)); - OP_REQUIRES_OK(ctx, resource->SetValue(dtype_, b->Tuple({ta, index}))); + OP_REQUIRES_OK(ctx, resource->SetValue(b->Tuple({ta, index}))); // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. auto start_indices = diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index f0525a5..91c1694 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -231,6 +231,7 @@ class StridedSliceAssignOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("new_axis_mask", &new_axis_mask_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("shrink_axis_mask", &shrink_axis_mask_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); } void Compile(XlaOpKernelContext* ctx) override { @@ -252,9 +253,9 @@ class StridedSliceAssignOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_, &strides_tensor)); - DataType lhs_type; TensorShape lhs_shape; - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &lhs_type, &lhs_shape)); + xla::ComputationDataHandle lhs; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &lhs_shape, &lhs)); const TensorShape rhs_shape = ctx->InputShape(4); @@ -282,9 +283,6 @@ class StridedSliceAssignOp : public XlaOpKernel { " does not match r-value shape ", rhs_shape.DebugString(), ". 
Automatic broadcasting not yet implemented.")); - xla::ComputationDataHandle lhs; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &lhs)); - xla::ComputationDataHandle rhs = ctx->Input(4); gtl::InlinedVector dimensions_to_reverse; @@ -320,13 +318,14 @@ class StridedSliceAssignOp : public XlaOpKernel { lhs, rhs, ctx->builder()->ConstantR1(slice_begin)); } - OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, lhs_type, lhs)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, lhs)); } private: int32 begin_mask_, end_mask_; int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_; DataType index_type_; + DataType dtype_; }; REGISTER_XLA_OP(Name("ResourceStridedSliceAssign") diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 9224072..7cf9b79 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -62,15 +62,13 @@ Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder, TF_RET_CHECK(resource->tensor_array_size() >= 0) << resource->name() << " size " << resource->tensor_array_size(); - TensorShape ta_shape; - ta_shape.AddDim(resource->tensor_array_size()); - ta_shape.AppendShape(elem_shape); if (!resource->initialized()) { xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, resource->type()); - TF_RETURN_IF_ERROR(resource->SetValue( - dtype, builder->Broadcast(zero, ta_shape.dim_sizes()))); + + TF_RETURN_IF_ERROR(resource->SetTypeAndShape(dtype, elem_shape)); + TF_RETURN_IF_ERROR(resource->SetZeroValue(builder)); } else { // Checks the elem_shape matches the TensorArray shape. auto shape_or_status = builder->GetShape(resource->value()); @@ -80,6 +78,10 @@ Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder, TensorShape shape; TF_RETURN_IF_ERROR( XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), &shape)); + + TensorShape ta_shape; + ta_shape.AddDim(resource->tensor_array_size()); + ta_shape.AppendShape(elem_shape); if (ta_shape != shape) { return errors::InvalidArgument( "Mismatched TensorArray sizes: ", ta_shape.DebugString(), " vs ", @@ -114,10 +116,8 @@ Status CheckTensorArrayIsInitialized(const string& op_name, Status GetTensorArrayShape(const XlaResource* resource, xla::ComputationBuilder* builder, TensorShape* shape) { - TF_RETURN_IF_ERROR(resource->GetShape(builder, shape)); - if (shape->dims() < 1) { - return errors::InvalidArgument("TensorArray rank must be >= 1"); - } + *shape = resource->shape(); + shape->InsertDim(0, resource->tensor_array_size()); return Status::OK(); } @@ -160,8 +160,8 @@ class TensorArrayOp : public XlaOpKernel { // Initializes the TensorArray value if we know the element shape. // Otherwise, defer initialization to the first write. 
xla::ComputationDataHandle value; + TensorShape shape; if (element_shape_.IsFullyDefined()) { - TensorShape shape; CHECK(element_shape_.AsTensorShape(&shape)); TensorShape ta_shape; ta_shape.AddDim(size); @@ -175,8 +175,8 @@ class TensorArrayOp : public XlaOpKernel { string name = strings::StrCat("TensorArray: ", tensor_array_name_); OP_REQUIRES_OK( ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name), - dtype_, value, &var)); - var->set_tensor_array_size(size); + dtype_, shape, value, /*tensor_array_size=*/size, + /*tensor_array_gradients=*/{}, &var)); ctx->SetResourceOutput(0, var); Tensor flow(DT_FLOAT, TensorShape({})); @@ -230,7 +230,7 @@ class TensorArrayWriteOp : public XlaOpKernel { xla::ComputationDataHandle written = DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices); - OP_REQUIRES_OK(ctx, resource->SetValue(dtype_, written)); + OP_REQUIRES_OK(ctx, resource->SetValue(written)); ctx->SetOutput(0, flow); } @@ -421,7 +421,7 @@ class TensorArrayScatterOp : public XlaOpKernel { } } - OP_REQUIRES_OK(ctx, resource->SetValue(dtype_, ta)); + OP_REQUIRES_OK(ctx, resource->SetValue(ta)); ctx->SetOutput(0, flow); } @@ -525,9 +525,8 @@ class TensorArraySplitOp : public XlaOpKernel { value_shape.DebugString(), " vs. ", ta_shape.DebugString())); - OP_REQUIRES_OK( - ctx, resource->SetValue( - dtype_, b->Add(ta, b->Reshape(value, ta_shape.dim_sizes())))); + OP_REQUIRES_OK(ctx, resource->SetValue(b->Add( + ta, b->Reshape(value, ta_shape.dim_sizes())))); ctx->SetOutput(0, flow); } diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc index 5534d1b..f750f70 100644 --- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc @@ -32,9 +32,24 @@ class ResourceApplyGradientDescent : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::ComputationDataHandle handle; xla::ComputationBuilder* b = ctx->builder(); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle)); + DataType type = ctx->input_type(1); + TensorShape var_shape; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &handle)); + + TensorShape alpha_shape = ctx->InputShape(1); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape), + errors::InvalidArgument("alpha is not a scalar: ", + alpha_shape.DebugString())); + + TensorShape delta_shape = ctx->InputShape(2); + OP_REQUIRES( + ctx, var_shape.IsSameSize(delta_shape), + errors::InvalidArgument("var and delta do not have the same shape: ", + var_shape.DebugString(), " vs ", + delta_shape.DebugString())); + handle = b->Sub(handle, b->Mul(ctx->Input(1), ctx->Input(2))); - OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, ctx->input_type(1), handle)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle)); } }; REGISTER_XLA_OP( @@ -52,18 +67,10 @@ class ResourceApplyMomentum : public XlaOpKernel { DataType type = ctx->input_type(2); - DataType var_type, accum_type; TensorShape var_shape, accum_shape; - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape)); - OP_REQUIRES_OK(ctx, - ctx->GetVariableTypeAndShape(1, &accum_type, &accum_shape)); - - OP_REQUIRES( - ctx, type == var_type && type == accum_type, - errors::InvalidArgument( - "Types of variable arguments to ResourceApplyMomentum must match: ", - DataTypeString(type), " vs. 
", DataTypeString(var_type), " and ", - DataTypeString(accum_type))); + xla::ComputationDataHandle var, accum; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum)); OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape), errors::InvalidArgument( @@ -86,10 +93,6 @@ class ResourceApplyMomentum : public XlaOpKernel { errors::InvalidArgument("momentum is not a scalar: ", momentum_shape.DebugString())); - xla::ComputationDataHandle var, accum; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var)); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &accum)); - xla::ComputationDataHandle lr = ctx->Input(2); xla::ComputationDataHandle grad = ctx->Input(3); xla::ComputationDataHandle momentum = ctx->Input(4); @@ -122,18 +125,10 @@ class ResourceApplyAdagrad : public XlaOpKernel { DataType type = ctx->input_type(2); - DataType var_type, accum_type; TensorShape var_shape, accum_shape; - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape)); - OP_REQUIRES_OK(ctx, - ctx->GetVariableTypeAndShape(1, &accum_type, &accum_shape)); - - OP_REQUIRES( - ctx, type == var_type && type == accum_type, - errors::InvalidArgument( - "Types of variable arguments to ResourceApplyAdagrad must match: ", - DataTypeString(type), " vs. ", DataTypeString(var_type), " and ", - DataTypeString(accum_type))); + xla::ComputationDataHandle var, accum; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum)); OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape), errors::InvalidArgument( @@ -151,9 +146,6 @@ class ResourceApplyAdagrad : public XlaOpKernel { "var and grad do not have the same shape", var_shape.DebugString(), " ", grad_shape.DebugString())); - xla::ComputationDataHandle var, accum; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var)); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &accum)); xla::ComputationDataHandle lr = ctx->Input(2); xla::ComputationDataHandle grad = ctx->Input(3); @@ -175,18 +167,11 @@ class ResourceApplyAdam : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - DataType var_type, m_type, v_type; TensorShape var_shape, m_shape, v_shape; - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape)); - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(1, &m_type, &m_shape)); - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(2, &v_type, &v_shape)); - - OP_REQUIRES( - ctx, dtype_ == var_type && dtype_ == m_type && dtype_ == v_type, - errors::InvalidArgument( - "Types of variable arguments to ResourceApplyRMSProp must match: ", - DataTypeString(dtype_), " vs. ", DataTypeString(var_type), " vs. ", - DataTypeString(m_type), " vs. 
", DataTypeString(v_type))); + xla::ComputationDataHandle var, m, v; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype_, &m_shape, &m)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &v_shape, &v)); TensorShape beta1_power_shape = ctx->InputShape(3); TensorShape beta2_power_shape = ctx->InputShape(4); @@ -228,10 +213,6 @@ class ResourceApplyAdam : public XlaOpKernel { "var and grad do not have the same shape", var_shape.DebugString(), " ", grad_shape.DebugString())); - xla::ComputationDataHandle var, m, v; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var)); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &m)); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, &v)); xla::ComputationDataHandle beta1_power = ctx->Input(3); xla::ComputationDataHandle beta2_power = ctx->Input(4); xla::ComputationDataHandle lr = ctx->Input(5); @@ -278,18 +259,11 @@ class ResourceApplyRMSProp : public XlaOpKernel { DataType type = ctx->input_type(3); - DataType var_type, ms_type, mom_type; TensorShape var_shape, ms_shape, mom_shape; - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape)); - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(1, &ms_type, &ms_shape)); - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(2, &mom_type, &mom_shape)); - - OP_REQUIRES( - ctx, type == var_type && type == ms_type && type == mom_type, - errors::InvalidArgument( - "Types of variable arguments to ResourceApplyRMSProp must match: ", - DataTypeString(type), " vs. ", DataTypeString(var_type), " vs. ", - DataTypeString(ms_type), " vs. ", DataTypeString(mom_type))); + xla::ComputationDataHandle var, ms, mom; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &ms_shape, &ms)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, type, &mom_shape, &mom)); TensorShape lr_shape = ctx->InputShape(3); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape), @@ -323,10 +297,6 @@ class ResourceApplyRMSProp : public XlaOpKernel { "var and grad do not have the same shape", var_shape.DebugString(), " ", grad_shape.DebugString())); - xla::ComputationDataHandle var, ms, mom; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var)); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &ms)); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, &mom)); xla::ComputationDataHandle lr = ctx->Input(3); xla::ComputationDataHandle rho = ctx->Input(4); xla::ComputationDataHandle momentum = ctx->Input(5); @@ -373,20 +343,11 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype, bool has_l2_shrinkage) { xla::ComputationBuilder* b = ctx->builder(); - DataType var_type, accum_type, linear_type; TensorShape var_shape, accum_shape, linear_shape; - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape)); - OP_REQUIRES_OK(ctx, - ctx->GetVariableTypeAndShape(1, &accum_type, &accum_shape)); - OP_REQUIRES_OK(ctx, - ctx->GetVariableTypeAndShape(2, &linear_type, &linear_shape)); - - OP_REQUIRES( - ctx, dtype == var_type && dtype == accum_type && dtype == linear_type, - errors::InvalidArgument( - "Types of variable arguments to ResourceApplyFtrlV2 must match: ", - DataTypeString(dtype), " vs. 
", DataTypeString(var_type), " and ", - DataTypeString(accum_type), " and ", DataTypeString(linear_type))); + xla::ComputationDataHandle var, accum, linear; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype, &var_shape, &var)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype, &accum_shape, &accum)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype, &linear_shape, &linear)); OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape), errors::InvalidArgument( @@ -438,10 +399,6 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype, errors::InvalidArgument("lr_power is not a scalar: ", lr_power_shape.DebugString())); - xla::ComputationDataHandle var, accum, linear; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var)); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &accum)); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, &linear)); xla::ComputationDataHandle grad = ctx->Input(3); xla::ComputationDataHandle lr = ctx->Input(4); xla::ComputationDataHandle l1 = ctx->Input(5); diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index 68847ae..e4079eb 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -33,21 +33,29 @@ class VarIsInitializedOp : public XlaOpKernel { public: explicit VarIsInitializedOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationDataHandle handle; - bool initialized = ctx->ReadVariableInput(0, &handle).ok(); - ctx->SetOutput(0, ctx->builder()->ConstantR0(initialized)); + XlaResource* variable; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &variable)); + ctx->SetOutput(0, + ctx->builder()->ConstantR0(variable->initialized())); } }; REGISTER_XLA_OP(Name("VarIsInitializedOp"), VarIsInitializedOp); class ReadVariableOp : public XlaOpKernel { public: - explicit ReadVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + explicit ReadVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + } + void Compile(XlaOpKernelContext* ctx) override { xla::ComputationDataHandle handle; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle)); + OP_REQUIRES_OK( + ctx, ctx->ReadVariableInput(0, dtype_, /*shape=*/nullptr, &handle)); ctx->SetOutput(0, handle); } + + private: + DataType dtype_; }; REGISTER_XLA_OP(Name("ReadVariableOp"), ReadVariableOp); @@ -65,10 +73,12 @@ class AssignAddVariableOp : public XlaOpKernel { public: explicit AssignAddVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { + DataType type = ctx->input_type(1); xla::ComputationDataHandle handle; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle)); + OP_REQUIRES_OK(ctx, + ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle)); handle = ctx->builder()->Add(handle, ctx->Input(1)); - OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, ctx->input_type(1), handle)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle)); } }; REGISTER_XLA_OP( @@ -79,10 +89,12 @@ class AssignSubVariableOp : public XlaOpKernel { public: explicit AssignSubVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { + DataType type = ctx->input_type(1); xla::ComputationDataHandle handle; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle)); + OP_REQUIRES_OK(ctx, + ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle)); handle = 
ctx->builder()->Sub(handle, ctx->Input(1)); - OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, ctx->input_type(1), handle)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle)); } }; REGISTER_XLA_OP( @@ -95,28 +107,19 @@ class ResourceGatherOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::ComputationBuilder* builder = ctx->builder(); - // Get the shape of the resource tensor. - TensorShape resource_shape; - DataType resource_dtype; - OP_REQUIRES_OK( - ctx, ctx->GetVariableTypeAndShape(0, &resource_dtype, &resource_shape)); - - DataType expected_output_dtype = ctx->expected_output_dtype(0); - OP_REQUIRES(ctx, resource_dtype == expected_output_dtype, - errors::InvalidArgument( - "Variable dtype is ", DataTypeString(resource_dtype), - " but expected output dtype is ", - DataTypeString(expected_output_dtype), ".")); + DataType type = ctx->expected_output_dtype(0); + TensorShape resource_shape; xla::ComputationDataHandle resource_handle; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &resource_handle)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &resource_shape, + &resource_handle)); auto indices = ctx->Input(1); auto indices_shape = ctx->InputShape(1); DataType index_type = ctx->input_type(1); xla::ComputationDataHandle gather = XlaComputeGatherDynamicSlice( - ctx, resource_handle, resource_shape, indices, indices_shape, 0, - resource_dtype, index_type, builder); + ctx, resource_handle, resource_shape, indices, indices_shape, 0, type, + index_type, builder); ctx->SetOutput(0, gather); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index 4a711e4..0ff1b65 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -58,9 +58,8 @@ Status MakeXlaCompilerArgumentsFromInputs( } arg.type = resource->type(); - if (arg.initialized) { - TF_RETURN_IF_ERROR(resource->PackedShape(ctx->builder(), &arg.shape)); - } else { + arg.shape = resource->shape(); + if (!arg.initialized) { *has_uninitialized_vars = true; } arg.tensor_array_size = resource->tensor_array_size(); @@ -70,14 +69,13 @@ Status MakeXlaCompilerArgumentsFromInputs( arg.name = resource->name(); VLOG(2) << " resource " << resource->name() << " type: " << DataTypeString(arg.type) - << " shape: " << xla::ShapeUtil::HumanString(arg.shape) + << " shape: " << arg.shape.DebugString() << " initialized: " << arg.initialized; } else { arg.kind = XlaCompiler::Argument::kParameter; arg.type = ctx->input_type(i); - TF_RETURN_IF_ERROR( - TensorShapeToXLAShape(arg.type, ctx->InputShape(i), &arg.shape)); + arg.shape = ctx->InputShape(i); } } return Status::OK(); @@ -154,17 +152,14 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { XlaCompiler::Argument& arg = arguments[update.input_index]; if (!arg.initialized) { VLOG(2) << "Update shape for argument " << update.input_index << " " - << xla::ShapeUtil::HumanString(update.shape); + << update.shape.DebugString(); arg.initialized = true; - xla::Shape shape = update.shape; - if (!update.tensor_array_gradients_accessed.empty()) { - shape = xla::ShapeUtil::GetTupleElementShape(shape, 0); - } - std::unique_ptr zero = - xla::Literal::CreateFromShape(shape); - OP_REQUIRES_OK(ctx, resource->SetValue( - update.type, builder->ConstantLiteral(*zero))); + arg.shape = update.shape; + OP_REQUIRES_OK(ctx, + resource->SetTypeAndShape(update.type, update.shape)); + + OP_REQUIRES_OK(ctx, resource->SetZeroValue(builder)); } // Add any TensorArray gradients touched 
by the body to the enclosing @@ -182,9 +177,6 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { for (const auto& gradient : resource->tensor_array_gradients()) { arg.tensor_array_gradients.insert(gradient.first); } - - // Recompute the argument shape. - OP_REQUIRES_OK(ctx, resource->PackedShape(ctx->builder(), &arg.shape)); } // Recompile the body with the "correct" resource shapes. VLOG(1) << "Recompiling body with corrected resource shapes"; @@ -292,13 +284,12 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { OP_REQUIRES_OK(ctx, resource->SetFromPack( arguments[update.input_index].tensor_array_gradients, - builder->GetTupleElement(while_result, pos), - /*reset_initial_values=*/false, builder)); + builder->GetTupleElement(while_result, pos), builder)); } VLOG(2) << "Loop-carried variable: pos: " << update.input_index << " name: " << resource->name() << " modified: " << update.modified << " type: " << DataTypeString(update.type) - << " shape: " << xla::ShapeUtil::HumanString(update.shape); + << " shape: " << update.shape.DebugString(); // Copies the identity of the resource variable from input to output // unchanged, even if the variable was not modified. ctx->op_kernel_context()->set_output( diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index 906f229..6051d7d 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -241,9 +241,7 @@ Status CreateXlaArgs(const Graph& graph, XlaCompiler::Argument arg; arg.kind = XlaCompiler::Argument::kParameter; TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &arg.type)); - TensorShape shape; - TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &shape)); - TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, &arg.shape)); + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &arg.shape)); TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kDebugNameAttr, &arg.name)); xla_args->push_back(arg); } diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 69b2654..c5b4ec5 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -66,13 +66,14 @@ Status CheckSignature(const DataTypeVector& types, bool XlaCompiler::Argument::operator==( const XlaCompiler::Argument& other) const { - if (std::tie(kind, resource_kind, type, name, tensor_array_size, + if (std::tie(kind, resource_kind, type, name, initialized, tensor_array_size, tensor_array_gradients) != std::tie(other.kind, other.resource_kind, other.type, other.name, - other.tensor_array_size, other.tensor_array_gradients)) { + other.initialized, other.tensor_array_size, + other.tensor_array_gradients)) { return false; } - if (!xla::ShapeUtil::Equal(shape, other.shape)) { + if (shape != other.shape) { return false; } if (constant_value.shape() != other.constant_value.shape()) { @@ -230,6 +231,64 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, return Status::OK(); } +// Computes the XLA shape for argument 'arg'. 
+/*static*/ Status XlaCompiler::XLAShapeForArgument( + const XlaCompiler::Argument& arg, xla::Shape* xla_shape) { + switch (arg.kind) { + case XlaCompiler::Argument::kConstant: + return TensorShapeToXLAShape(arg.type, arg.constant_value.shape(), + xla_shape); + case XlaCompiler::Argument::kParameter: + return TensorShapeToXLAShape(arg.type, arg.shape, xla_shape); + case XlaCompiler::Argument::kResource: { + TF_RET_CHECK(arg.initialized); + + switch (arg.resource_kind) { + case XlaResource::kVariable: + return TensorShapeToXLAShape(arg.type, arg.shape, xla_shape); + case XlaResource::kTensorArray: { + if (arg.tensor_array_size < 0) { + return errors::InvalidArgument( + "Negative tensor_array_size in XLAShapeForArgument"); + } + TensorShape shape; + shape.AddDim(arg.tensor_array_size); + shape.AppendShape(arg.shape); + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, xla_shape)); + + if (!arg.tensor_array_gradients.empty()) { + std::vector tuple_shape( + arg.tensor_array_gradients.size() + 1, *xla_shape); + *xla_shape = xla::ShapeUtil::MakeTupleShape(tuple_shape); + } + return Status::OK(); + } + case XlaResource::kStack: { + if (arg.tensor_array_size < 0) { + return errors::InvalidArgument( + "Negative tensor_array_size in XLAShapeForArgument"); + } + TensorShape shape; + shape.AddDim(arg.tensor_array_size); + shape.AppendShape(arg.shape); + xla::Shape buffer_shape; + TF_RETURN_IF_ERROR( + TensorShapeToXLAShape(arg.type, shape, &buffer_shape)); + *xla_shape = xla::ShapeUtil::MakeTupleShape( + {buffer_shape, xla::ShapeUtil::MakeShape(xla::S32, {})}); + return Status::OK(); + } + + case XlaResource::kInvalid: + return errors::Internal( + "Invalid resource type in XLAShapeForArgument()"); + } + } + case XlaCompiler::Argument::kInvalid: + return errors::Internal("Invalid argument type in XLAShapeForArgument()"); + } +} + namespace { Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, @@ -275,8 +334,9 @@ Status BuildArguments(const Graph& graph, // Argument numbers of arguments and resources that are to be passed to the // XLA computation as runtime parameters. - std::vector parameters, resources; - parameters.reserve(args.size()); + input_mapping->clear(); + input_mapping->reserve(args.size()); + std::vector resources; resources.reserve(args.size()); // Fills in constant arguments, and computes non-constant argument order. @@ -290,18 +350,20 @@ Status BuildArguments(const Graph& graph, // TODO(phawkins): this code assumes that resource arguments do not // alias. XlaResource* resource; - TF_RETURN_IF_ERROR( - context->CreateResource(arg.resource_kind, i, arg.name, arg.type, - xla::ComputationDataHandle(), &resource)); - resource->set_tensor_array_size(arg.tensor_array_size); + TF_RETURN_IF_ERROR(context->CreateResource( + arg.resource_kind, i, arg.name, arg.type, arg.shape, + xla::ComputationDataHandle(), + /*tensor_array_size=*/arg.tensor_array_size, + /*tensor_array_gradients=*/arg.tensor_array_gradients, &resource)); arg_expression.set_resource(resource); if (arg.initialized) { resources.push_back(i); } break; - case XlaCompiler::Argument::kParameter: - parameters.push_back(i); + case XlaCompiler::Argument::kParameter: { + input_mapping->push_back(i); break; + } case XlaCompiler::Argument::kConstant: arg_expression.set_constant_value(arg.constant_value); break; @@ -312,19 +374,17 @@ Status BuildArguments(const Graph& graph, // Append parameters containing variable values after the other runtime // parameters. 
- parameters.insert(parameters.end(), resources.begin(), resources.end()); - if (parameters.empty()) { + input_mapping->insert(input_mapping->end(), resources.begin(), + resources.end()); + if (input_mapping->empty()) { return Status::OK(); } - std::vector arg_shapes; - arg_shapes.reserve(parameters.size()); - input_mapping->resize(parameters.size()); - for (std::vector::size_type i = 0; i < parameters.size(); ++i) { - const XlaCompiler::Argument& arg = args[parameters[i]]; + std::vector arg_shapes(input_mapping->size()); + for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { // Computes the shapes of non-constant arguments. - arg_shapes.push_back(arg.shape); - (*input_mapping)[i] = parameters[i]; + TF_RETURN_IF_ERROR(XlaCompiler::XLAShapeForArgument( + args[(*input_mapping)[i]], &arg_shapes[i])); } if (use_tuple_arg) { @@ -354,13 +414,13 @@ Status BuildArguments(const Graph& graph, } // Build parameter handles for non-constant arguments. - std::vector arg_handles(parameters.size()); + std::vector arg_handles(input_mapping->size()); if (use_tuple_arg) { xla::ComputationDataHandle tuple; if (is_entry_computation) { xla::OpSharding tuple_sharding; tuple_sharding.set_type(xla::OpSharding::Type::OpSharding_Type_TUPLE); - for (int64 parameter : parameters) { + for (int64 parameter : *input_mapping) { const int core = (*arg_cores)[parameter]; const int root_device = 0; *tuple_sharding.add_tuple_shardings() = @@ -373,16 +433,16 @@ Status BuildArguments(const Graph& graph, } else { tuple = builder->Parameter(0, (*input_shapes)[0], "arg_tuple"); } - for (std::vector::size_type i = 0; i < parameters.size(); ++i) { - const int core = (*arg_cores)[parameters[i]]; + for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { + const int core = (*arg_cores)[input_mapping->at(i)]; xla::ScopedShardingAssignment assign_sharding( builder, core == -1 ? tensorflow::gtl::optional() : xla::sharding_builder::AssignDevice(core)); arg_handles[i] = builder->GetTupleElement(tuple, i); } } else { - for (std::vector::size_type i = 0; i < parameters.size(); ++i) { - const int core = (*arg_cores)[parameters[i]]; + for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { + const int core = (*arg_cores)[input_mapping->at(i)]; xla::ScopedShardingAssignment assign_sharding( builder, core == -1 ? tensorflow::gtl::optional() : xla::sharding_builder::AssignDevice(core)); @@ -393,19 +453,18 @@ Status BuildArguments(const Graph& graph, // Fill in the handles in non-constant arguments. 
VLOG(2) << "XLA computation inputs:"; - for (std::vector::size_type i = 0; i < parameters.size(); ++i) { - const XlaCompiler::Argument& arg = args[parameters[i]]; + for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { + const XlaCompiler::Argument& arg = args[input_mapping->at(i)]; VLOG(2) << " XLA arg " << i << " shape: " << xla::ShapeUtil::HumanString(arg_shapes[i]) - << " name: " << arg.name << " TF arg " << parameters[i]; - XlaExpression& arg_expression = (*arg_expressions)[parameters[i]]; + << " name: " << arg.name << " TF arg " << input_mapping->at(i); + XlaExpression& arg_expression = (*arg_expressions)[input_mapping->at(i)]; switch (arg.kind) { case XlaCompiler::Argument::kResource: { TF_RET_CHECK(arg.initialized); XlaResource* resource = arg_expression.resource(); - TF_RETURN_IF_ERROR( - resource->SetFromPack(arg.tensor_array_gradients, arg_handles[i], - /*reset_initial_values=*/true, builder)); + TF_RETURN_IF_ERROR(resource->SetFromPack(arg.tensor_array_gradients, + arg_handles[i], builder)); VLOG(2) << " resource: num_gradients: " << arg.tensor_array_gradients.size(); break; @@ -486,6 +545,7 @@ Status BuildComputation( XlaCompiler::ResourceUpdate& update = resource_updates->back(); update.input_index = resource->arg_num(); update.type = resource->type(); + update.shape = resource->shape(); update.modified = modified; for (const auto& grad : resource->tensor_array_gradients()) { update.tensor_array_gradients_accessed.insert(grad.first); @@ -616,13 +676,6 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, ++computation_output; } } - - for (std::vector::size_type i = 0; - i < result->resource_updates.size(); ++i) { - result->resource_updates[i].shape = xla::ShapeUtil::GetTupleElementShape( - result->xla_output_shape, computation_output); - ++computation_output; - } return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 30d3c05..b86c82c 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -104,9 +104,17 @@ class XlaCompiler { // is the type of the variable's value, not DT_RESOURCE. DataType type; - // The shape of the argument. If the argument is a resource, this is the - // shape of the resource's value. - xla::Shape shape; + // The shape of the argument. For: + // * a parameter: the shape of the parameter. + // * a constant: ignored; the shape given by constant_value is used + // instead. + // * an uninitialized resource: ignored. We don't yet know the shape of an + // uninitialized resource (otherwise we would have initialized it!) + // * an initialized variable: the shape of the variable's value. + // * an initialized TensorArray or Stack resource: the shape of an entry in + // the TensorArray/Stack. Note this is the size of a single entry, not the + // XLA data structure that represents the complete stack/array. + TensorShape shape; // The value of the argument, if it is a compile-time constant. Must be a // host-memory tensor. @@ -175,8 +183,9 @@ class XlaCompiler { int input_index; // Type and shape of the tensor to be written back. + // The `shape` field has the same meaning as the Argument::shape field. DataType type; - xla::Shape shape; + TensorShape shape; // Was the value of the variable modified by the computation? // (Always true, unless `return_updated_values_for_all_resources` is true.) 
@@ -266,11 +275,10 @@ class XlaCompiler { const std::vector& args, CompilationResult* result); - Status PrepareArguments(xla::ComputationBuilder* builder, NameAttrList func, - const std::vector& types, - const std::vector& shapes, - const std::vector& expressions, - std::vector* args); + // Returns the shape of the XLA parameter for an argument 'arg'. + // See the class comment for more details about the argument passing + // convention. + static Status XLAShapeForArgument(const Argument& arg, xla::Shape* xla_shape); // Retrieves the channel handle associated with `key`. Allocates // a new channel handle if none exists. diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 7ebe4b7..65de4db 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -191,10 +191,10 @@ TEST_F(XlaCompilerTest, Simple) { std::vector args(2); args[0].kind = XlaCompiler::Argument::kParameter; args[0].type = DT_INT32; - args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2}); + args[0].shape = TensorShape({2}); args[1].kind = XlaCompiler::Argument::kParameter; args[1].type = DT_INT32; - args[1].shape = xla::ShapeUtil::MakeShape(xla::S32, {2}); + args[1].shape = TensorShape({2}); // Compiles the graph. XlaCompiler compiler(DefaultOptions()); @@ -242,10 +242,10 @@ TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) { std::vector args(2); args[0].kind = XlaCompiler::Argument::kParameter; args[0].type = DT_INT32; - args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2}); + args[0].shape = TensorShape({2}); args[1].kind = XlaCompiler::Argument::kParameter; args[1].type = DT_INT32; - args[1].shape = xla::ShapeUtil::MakeShape(xla::S32, {2}); + args[1].shape = TensorShape({2}); // Compiles the graph. XlaCompiler compiler(DefaultOptions()); @@ -281,7 +281,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { std::vector args(1); args[0].kind = XlaCompiler::Argument::kParameter; args[0].type = DT_INT32; - args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2}); + args[0].shape = TensorShape({2}); XlaCompiler::Options options = DefaultOptions(); XlaCompiler compiler(options); @@ -373,7 +373,7 @@ TEST_F(XlaCompilerTest, ResourceManager) { std::vector args(1); args[0].kind = XlaCompiler::Argument::kParameter; args[0].type = DT_INT32; - args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2}); + args[0].shape = TensorShape({2}); DummyResourceForTest* resource = new DummyResourceForTest(); @@ -420,7 +420,7 @@ TEST_F(XlaCompilerTest, DeterministicCompilation) { std::vector args(1); args[0].kind = XlaCompiler::Argument::kParameter; args[0].type = DT_INT32; - args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2}); + args[0].shape = TensorShape({2}); // Compiles the graph. 
auto options = DefaultOptions(); @@ -472,9 +472,7 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) { args[0].resource_kind = XlaResource::kTensorArray; args[0].initialized = true; args[0].type = DT_INT32; - args[0].shape = xla::ShapeUtil::MakeTupleShape( - {xla::ShapeUtil::MakeShape(xla::S32, {2}), - xla::ShapeUtil::MakeShape(xla::S32, {2})}); + args[0].shape = TensorShape({}); args[0].tensor_array_size = 2; args[0].tensor_array_gradients = {"grad2"}; @@ -540,9 +538,7 @@ TEST_F(XlaCompilerTest, UnwrittenTensorArrayGradientsAreNotComputationOutputs) { args[0].resource_kind = XlaResource::kTensorArray; args[0].initialized = true; args[0].type = DT_INT32; - args[0].shape = xla::ShapeUtil::MakeTupleShape( - {xla::ShapeUtil::MakeShape(xla::S32, {2}), - xla::ShapeUtil::MakeShape(xla::S32, {2})}); + args[0].shape = TensorShape({}); args[0].tensor_array_size = 2; args[0].tensor_array_gradients = {"grad1"}; @@ -574,9 +570,7 @@ TEST_F(XlaCompilerTest, NewTensorArrayGradientsAreComputationOutputs) { args[0].resource_kind = XlaResource::kTensorArray; args[0].initialized = true; args[0].type = DT_INT32; - args[0].shape = xla::ShapeUtil::MakeTupleShape( - {xla::ShapeUtil::MakeShape(xla::S32, {2}), - xla::ShapeUtil::MakeShape(xla::S32, {2})}); + args[0].shape = TensorShape({}); args[0].tensor_array_size = 2; args[0].tensor_array_gradients = {"grad1"}; diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index e8d17e2..7387895 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -103,12 +103,14 @@ Status XlaContext::AddConstRetval(int retval_index, DataType dtype, xla::ComputationBuilder* XlaContext::builder() { return builder_; } -Status XlaContext::CreateResource(XlaResource::Kind kind, int arg_num, - string name, DataType type, - const xla::ComputationDataHandle& handle, - XlaResource** resource) { +Status XlaContext::CreateResource( + XlaResource::Kind kind, int arg_num, string name, DataType type, + TensorShape shape, const xla::ComputationDataHandle& handle, + int64 tensor_array_size, const std::set& tensor_array_gradients, + XlaResource** resource) { resources_.emplace_back( - new XlaResource(kind, arg_num, std::move(name), type, handle)); + new XlaResource(kind, arg_num, std::move(name), type, std::move(shape), + handle, tensor_array_size, tensor_array_gradients)); *resource = resources_.back().get(); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index 1a7dafe..fac0352 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -71,11 +71,15 @@ class XlaContext : public ResourceBase { Status AddConstRetval(int retval_index, DataType dtype, const xla::Literal& literal); - // Creates a resource with resource `kind` and initial type `type` and - // value `handle`. `name` is a descriptive name for use in error messages. + // Creates a resource with resource `kind` and initial value `handle`. `name` + // is a descriptive name for use in error messages. See the `XlaResource` + // constructor for a description of the remaining arguments. // Fails if the resource already exists. 
Status CreateResource(XlaResource::Kind kind, int arg_num, string name, - DataType type, const xla::ComputationDataHandle& handle, + DataType type, TensorShape shape, + const xla::ComputationDataHandle& handle, + int64 tensor_array_size, + const std::set& tensor_array_gradients, XlaResource** resource); const std::vector>& resources() { diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index ee0aed6..ee29158 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -286,7 +286,8 @@ Status XlaOpKernelContext::ConstantInputList( } Status XlaOpKernelContext::ReadVariableInput( - int index, xla::ComputationDataHandle* value) { + int index, DataType type, TensorShape* shape, + xla::ComputationDataHandle* value) { const Tensor& tensor = context_->input(index); const XlaExpression* expression = CastExpressionFromTensor(tensor); XlaResource* variable = expression->resource(); @@ -296,7 +297,15 @@ Status XlaOpKernelContext::ReadVariableInput( return errors::InvalidArgument("Read of uninitialized variable ", variable->name()); } + if (variable->type() != type) { + return errors::InvalidArgument( + "Type mismatch for read of variable ", variable->name(), ". Expected ", + DataTypeString(type), "; got ", DataTypeString(variable->type())); + } *value = variable->value(); + if (shape) { + *shape = variable->shape(); + } return Status::OK(); } @@ -312,12 +321,7 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type, variable->name()); } *type = variable->type(); - auto shape_or_status = builder()->GetShape(variable->value()); - if (!shape_or_status.ok()) { - return shape_or_status.status(); - } - TF_RETURN_IF_ERROR( - XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), shape)); + *shape = variable->shape(); return Status::OK(); } @@ -405,7 +409,17 @@ Status XlaOpKernelContext::AssignVariable( XlaResource* variable = expression->resource(); TF_RET_CHECK(variable != nullptr); TF_RET_CHECK(variable->kind() == XlaResource::kVariable); - return variable->SetValue(type, handle); + + auto shape_or_status = builder()->GetShape(handle); + if (!shape_or_status.ok()) { + return shape_or_status.status(); + } + TensorShape shape; + TF_RETURN_IF_ERROR( + XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), &shape)); + + TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape)); + return variable->SetValue(handle); } XlaCompiler* XlaOpKernelContext::compiler() const { diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 6d3b6db..e1fd0f5 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -164,11 +164,16 @@ class XlaOpKernelContext { TensorShape* shape) const; // Reads the current value of the resouce variable referred to by input - // 'index'. - Status ReadVariableInput(int index, xla::ComputationDataHandle* value); + // 'index'. If `shape` is not nullptr, sets `*shape` to the shape of the + // variable. Returns an error if the variable has not been initialized, or if + // its type does not match `type`. + Status ReadVariableInput(int index, DataType type, TensorShape* shape, + xla::ComputationDataHandle* value); // Assigns the value `handle` to the variable referenced by input - // `input_index`. Marks the operator as having side effects. + // `input_index`. The variable must be of `type`. 
Returns an error if the + // variable has been initialized with a different type or with a + // different shape. Status AssignVariable(int input_index, DataType type, const xla::ComputationDataHandle& handle); diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc index 9abac8b..c2075b4 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.cc +++ b/tensorflow/compiler/tf2xla/xla_resource.cc @@ -25,51 +25,99 @@ limitations under the License. namespace tensorflow { -XlaResource::XlaResource(Kind kind, int arg_num, string name, - DataType initial_type, - const xla::ComputationDataHandle& initial_value) +XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type, + TensorShape shape, + const xla::ComputationDataHandle& initial_value, + int64 tensor_array_size, + const std::set& tensor_array_gradients) : kind_(kind), arg_num_(arg_num), name_(std::move(name)), - type_(initial_type), + type_(type), + shape_(std::move(shape)), value_(initial_value), - initial_value_(initial_value) { + initial_value_(initial_value), + tensor_array_size_(tensor_array_size) { CHECK(kind_ != kInvalid); + + for (const string& gradient : tensor_array_gradients) { + tensor_array_gradients_[gradient].reset( + new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1, + /*name=*/strings::StrCat("TensorArrayGrad: ", name_), + type_, shape_, xla::ComputationDataHandle(), + tensor_array_size_, /*tensor_array_gradients=*/{})); + } } -Status XlaResource::SetValue(DataType type, - const xla::ComputationDataHandle& value) { - if (type_ == DT_INVALID && type == DT_INVALID) { - return errors::InvalidArgument("Attempted to initialized resource ", name_, - " to an invalid type"); +Status XlaResource::SetTypeAndShape(DataType type, const TensorShape& shape) { + if (type == DT_INVALID) { + return errors::InvalidArgument("Attempted to set type of resource '", name_, + "'' to an invalid type"); } - if (type_ != DT_INVALID && type_ != type) { + if (initialized() && type_ != type) { return errors::InvalidArgument("Type of resource ", name_, " cannot be changed after initialization: " "old type was ", DataTypeString(type_), ", new type is ", DataTypeString(type)); } + if (initialized() && shape_ != shape) { + return errors::InvalidArgument("Shape of resource ", name_, + " cannot be changed after initialization: " + "old shape was ", + shape_.DebugString(), ", new shape is ", + shape.DebugString()); + } type_ = type; - value_ = value; + shape_ = shape; return Status::OK(); } -Status XlaResource::GetXlaShape(xla::ComputationBuilder* builder, - xla::Shape* shape) const { - auto shape_or_status = builder->GetShape(value_); - if (!shape_or_status.ok()) { - return shape_or_status.status(); +Status XlaResource::SetValue(const xla::ComputationDataHandle& value) { + if (type_ == DT_INVALID) { + return errors::InvalidArgument( + "Resource '", name_, + "' must be initialized with a valid type before use."); } - *shape = *shape_or_status.ValueOrDie(); + value_ = value; return Status::OK(); } -Status XlaResource::GetShape(xla::ComputationBuilder* builder, - TensorShape* shape) const { - xla::Shape xla_shape; - TF_RETURN_IF_ERROR(GetXlaShape(builder, &xla_shape)); - TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, shape)); +Status XlaResource::SetZeroValue(xla::ComputationBuilder* builder) { + if (type_ == DT_INVALID) { + return errors::InvalidArgument( + "Resource '", name_, + "' must be initialized with a valid type before use."); + } + switch (kind_) { + case kVariable: { + value_ = 
builder->Broadcast(XlaHelpers::Zero(builder, type_), + shape_.dim_sizes()); + break; + } + case kTensorArray: { + TensorShape ta_shape; + ta_shape.AddDim(tensor_array_size_); + ta_shape.AppendShape(shape_); + value_ = builder->Broadcast(XlaHelpers::Zero(builder, type_), + ta_shape.dim_sizes()); + break; + } + case kStack: { + TensorShape ta_shape; + ta_shape.AddDim(tensor_array_size_); + ta_shape.AppendShape(shape_); + value_ = + builder->Tuple({builder->Broadcast(XlaHelpers::Zero(builder, type_), + ta_shape.dim_sizes()), + builder->ConstantR0(0)}); + break; + } + + case kInvalid: + default: + LOG(FATAL) << "Invalid resource type"; + } return Status::OK(); } @@ -82,36 +130,20 @@ Status XlaResource::GetOrCreateTensorArrayGradient( std::unique_ptr& gradient = tensor_array_gradients_[source]; if (!gradient) { TensorShape ta_shape; - TF_RETURN_IF_ERROR(GetShape(builder, &ta_shape)); + ta_shape.AddDim(tensor_array_size_); + ta_shape.AppendShape(shape_); xla::ComputationDataHandle gradient_value = builder->Broadcast( XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes()); gradient.reset( new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1, /*name=*/strings::StrCat("TensorArrayGrad: ", name_), - type_, gradient_value)); - gradient->tensor_array_size_ = tensor_array_size_; + type_, shape_, gradient_value, tensor_array_size_, + /*tensor_array_gradients=*/{})); } *gradient_out = gradient.get(); return Status::OK(); } -Status XlaResource::PackedShape(xla::ComputationBuilder* builder, - xla::Shape* packed_shape) const { - if (tensor_array_gradients_.empty()) { - return GetXlaShape(builder, packed_shape); - } - TF_RET_CHECK(kind_ == kTensorArray); - std::vector elem_shapes(1 + tensor_array_gradients_.size()); - int pos = 0; - TF_RETURN_IF_ERROR(GetXlaShape(builder, &elem_shapes[pos++])); - for (const auto& gradient : tensor_array_gradients_) { - TF_RETURN_IF_ERROR( - gradient.second->GetXlaShape(builder, &elem_shapes[pos++])); - } - *packed_shape = xla::ShapeUtil::MakeTupleShape(elem_shapes); - return Status::OK(); -} - Status XlaResource::Pack(xla::ComputationDataHandle* pack, xla::ComputationBuilder* builder) const { if (tensor_array_gradients_.empty()) { @@ -130,27 +162,32 @@ Status XlaResource::Pack(xla::ComputationDataHandle* pack, Status XlaResource::SetFromPack(const std::set& gradient_sources, const xla::ComputationDataHandle& pack, - bool reset_initial_values, xla::ComputationBuilder* builder) { if (gradient_sources.empty()) { + if (!initialized()) { + initial_value_ = pack; + } value_ = pack; } else { TF_RET_CHECK(kind_ == kTensorArray); int pos = 0; - value_ = builder->GetTupleElement(pack, pos++); + auto v = builder->GetTupleElement(pack, pos++); + if (!initialized()) { + initial_value_ = v; + } + value_ = v; + for (const auto& source : gradient_sources) { XlaResource* gradient; TF_RETURN_IF_ERROR( GetOrCreateTensorArrayGradient(source, builder, &gradient)); - gradient->value_ = builder->GetTupleElement(pack, pos++); - if (reset_initial_values) { - gradient->initial_value_ = gradient->value_; + auto v = builder->GetTupleElement(pack, pos++); + if (!gradient->initialized()) { + gradient->initial_value_ = v; } + gradient->value_ = v; } } - if (reset_initial_values) { - initial_value_ = value_; - } return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/xla_resource.h b/tensorflow/compiler/tf2xla/xla_resource.h index 6b46089..1bb2c72 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.h +++ b/tensorflow/compiler/tf2xla/xla_resource.h @@ -36,8 +36,11 @@ class XlaResource { kStack, }; 
- XlaResource(Kind kind, int arg_num, string name, DataType initial_type, - const xla::ComputationDataHandle& initial_value); + XlaResource(Kind kind, int arg_num, string name, DataType type, + TensorShape shape, + const xla::ComputationDataHandle& initial_value, + int64 tensor_array_size, + const std::set& tensor_array_gradients); XlaResource(const XlaResource&) = delete; XlaResource(XlaResource&&) = delete; @@ -60,6 +63,12 @@ class XlaResource { // a resource is first initialized we do not yet know its type, so we keep // track of its type dynamically. DataType type() const { return type_; } + + // Shape of the resource. For an uninitialized resource, this is ignored. + // For a Variable, this is the shape of the value. For a TensorArray or Stack + // this is the shape of each entry in the TensorArray/Stack. + const TensorShape& shape() const { return shape_; } + const xla::ComputationDataHandle& value() const { return value_; } // Value of the resource at computation entry. Used to detect which @@ -68,17 +77,19 @@ class XlaResource { return initial_value_; } + // A variable is initialized if it has a value. bool initialized() const { return value_.handle() > 0; } - // Sets the current type/value of the resource. - Status SetValue(DataType type, const xla::ComputationDataHandle& value); + // Sets the type and shape of the resource. The type and shape of a resource + // must not change once the variable has been initialized. + Status SetTypeAndShape(DataType type, const TensorShape& shape); - // Returns the shape of the resource as an xla::Shape. - Status GetXlaShape(xla::ComputationBuilder* builder, xla::Shape* shape) const; + // Sets the current value of the resource. Returns an error if the type is not + // set to a valid value. + Status SetValue(const xla::ComputationDataHandle& value); - // Returns the shape of the resource as an TensorShape. Fails if the shape is - // not representable as a TensorShape. - Status GetShape(xla::ComputationBuilder* builder, TensorShape* shape) const; + // Sets the current value of the resource to an all-zero value. + Status SetZeroValue(xla::ComputationBuilder* builder); // Looks up the gradient for `source`, or creates it if it does not already // exist. The call target must be an initialized TensorArray resource. A @@ -96,10 +107,6 @@ class XlaResource { Status Pack(xla::ComputationDataHandle* pack, xla::ComputationBuilder* builder) const; - // Returns the shape of the `pack` value computed by `Pack()`. - Status PackedShape(xla::ComputationBuilder* builder, - xla::Shape* packed_shape) const; - // Updates the resource with values from `pack`. If `gradient_sources` is // non-empty, treats `pack` as a tuple that represents a TensorArray and // its gradients, and unpacks and updates the gradient resources. @@ -108,14 +115,14 @@ class XlaResource { // Opposite of Pack(). Status SetFromPack(const std::set& gradient_sources, const xla::ComputationDataHandle& pack, - bool reset_initial_values, xla::ComputationBuilder* builder); - // TensorArray-specific fields + // TensorArray and Stack specific fields // 'tensor_array_size' stores the expected size of the TensorArray or Stack. // We need to store this since sometimes TensorArrays must be initialized // lazily since we do not know the element shape at construction time. + // Used by both TensorArrays and Stacks. 
int64 tensor_array_size() const { return tensor_array_size_; } void set_tensor_array_size(int64 size) { tensor_array_size_ = size; } @@ -136,6 +143,7 @@ class XlaResource { const string name_; DataType type_; + TensorShape shape_; xla::ComputationDataHandle value_; xla::ComputationDataHandle initial_value_; -- 2.7.4
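
The following self-contained C++ sketch is illustrative only and is not part of the patch. It uses hypothetical stand-in types (AbstractShape for TensorShape, XlaShapeModel for xla::Shape) to model the mapping that the new XlaCompiler::XLAShapeForArgument performs above: a Variable keeps its value shape, a TensorArray prepends tensor_array_size, and a Stack becomes a (buffer, position) tuple. The point is that the resource only stores the abstract per-element TensorFlow shape, and the XLA representation is derived in one place.

// Illustrative sketch only -- not part of the patch. Plain std:: types stand
// in for tensorflow::TensorShape and xla::Shape; the mapping mirrors the
// logic of XlaCompiler::XLAShapeForArgument introduced in this change.
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Stand-in for TensorShape: an ordered list of dimension sizes.
using AbstractShape = std::vector<int64_t>;

enum class ResourceKind { kVariable, kTensorArray, kStack };

// Stand-in for xla::Shape: either an array of dimensions or a tuple of shapes.
struct XlaShapeModel {
  bool is_tuple = false;
  AbstractShape dims;                   // valid when !is_tuple
  std::vector<XlaShapeModel> elements;  // valid when is_tuple
};

// Derives the on-device representation from the abstract per-element shape:
// a Variable keeps its shape, a TensorArray prepends the array size, and a
// Stack is a (buffer, position) pair.
XlaShapeModel XlaShapeForResource(ResourceKind kind,
                                  const AbstractShape& element_shape,
                                  int64_t tensor_array_size) {
  XlaShapeModel result;
  switch (kind) {
    case ResourceKind::kVariable:
      result.dims = element_shape;
      return result;
    case ResourceKind::kTensorArray:
      result.dims.push_back(tensor_array_size);
      result.dims.insert(result.dims.end(), element_shape.begin(),
                         element_shape.end());
      return result;
    case ResourceKind::kStack: {
      XlaShapeModel buffer;
      buffer.dims.push_back(tensor_array_size);
      buffer.dims.insert(buffer.dims.end(), element_shape.begin(),
                         element_shape.end());
      XlaShapeModel position;  // scalar stack pointer (S32 in the real code)
      result.is_tuple = true;
      result.elements = {buffer, position};
      return result;
    }
  }
  return result;
}

std::string ToString(const XlaShapeModel& s) {
  if (s.is_tuple) {
    std::string out = "(";
    for (size_t i = 0; i < s.elements.size(); ++i) {
      if (i) out += ", ";
      out += ToString(s.elements[i]);
    }
    return out + ")";
  }
  std::string out = "[";
  for (size_t i = 0; i < s.dims.size(); ++i) {
    if (i) out += ", ";
    out += std::to_string(s.dims[i]);
  }
  return out + "]";
}

int main() {
  // A Stack of ten 2x3 elements is presented to XLA as a ([10,2,3] buffer,
  // scalar position) tuple, while the resource itself tracks only the
  // abstract per-element shape [2,3] plus the size 10.
  AbstractShape element_shape = {2, 3};
  std::cout << ToString(XlaShapeForResource(ResourceKind::kStack,
                                            element_shape, 10))
            << "\n";  // prints ([10, 2, 3], [])
  std::cout << ToString(XlaShapeForResource(ResourceKind::kTensorArray,
                                            element_shape, 10))
            << "\n";  // prints [10, 2, 3]
  return 0;
}

Keeping only the abstract per-element shape on XlaResource, as the patch does, lets kernels such as stack_ops.cc and tensor_array_ops.cc call SetTypeAndShape and SetZeroValue without spelling out the tuple encoding themselves; the XLA-level layout is derived on demand when the compiler builds computation parameters.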