OP_REQUIRES(ctx,
write.input_index >= 0 && write.input_index < ctx->num_inputs(),
errors::Internal("Invalid input index for variable write."));
- TensorShape write_shape;
- OP_REQUIRES_OK(ctx, XLAShapeToTensorShape(write.shape, &write_shape));
gpu::DeviceMemoryBase buffer = output->buffer({output_num});
// Looks up the owning Tensor by buffer address.
OP_REQUIRES_OK(
- ctx, xla_allocator.MakeTensorFromBuffer(buffer, write.type, write_shape,
+ ctx, xla_allocator.MakeTensorFromBuffer(buffer, write.type, write.shape,
variable->tensor()));
++output_num;
}
XlaCompiler::Argument& arg = (*args)[input_num];
arg.kind = XlaCompiler::Argument::kConstant;
arg.type = input.dtype();
- TF_RETURN_IF_ERROR(
- TensorShapeToXLAShape(input.dtype(), input.shape(), &arg.shape));
+ arg.shape = input.shape();
arg.constant_value = input;
++input_num;
}
arg.constant_value = input;
}
arg.type = input.dtype();
- TF_RETURN_IF_ERROR(
- TensorShapeToXLAShape(input.dtype(), input.shape(), &arg.shape));
+ arg.shape = input.shape();
++input_num;
}
if (variable_args[variable_id].present) {
const Tensor& value = variable_args[variable_id].value;
arg.type = value.dtype();
- TF_RETURN_IF_ERROR(
- TensorShapeToXLAShape(value.dtype(), value.shape(), &arg.shape));
+ arg.shape = value.shape();
arg.initialized = true;
} else {
// The values of uninitialized variables are not passed as inputs, since
// uninitialized variables.
arg.initialized = false;
arg.type = DT_INVALID;
- arg.shape = xla::Shape();
+ arg.shape = TensorShape();
}
++input_num;
}
for (int i = 0; i < args->size(); ++i) {
XlaCompiler::Argument& arg = (*args)[i];
arg.type = ctx->input_type(i);
-
- TF_RETURN_IF_ERROR(
- TensorShapeToXLAShape(arg.type, ctx->InputShape(i), &arg.shape));
+ arg.shape = ctx->InputShape(i);
if (arg.type == DT_RESOURCE) {
return errors::InvalidArgument(
// Stack has not been initialized.
xla::ComputationDataHandle zero =
XlaHelpers::Zero(builder, resource->type());
- TF_RETURN_IF_ERROR(resource->SetValue(
- dtype,
- builder->Tuple({builder->Broadcast(zero, stack_shape.dim_sizes()),
- builder->ConstantR0<int32>(0)})));
+ TF_RETURN_IF_ERROR(resource->SetTypeAndShape(dtype, elem_shape));
+ TF_RETURN_IF_ERROR(resource->SetZeroValue(builder));
} else {
// Checks the expected shape matches the actual shape.
TensorShape actual_shape;
string name = strings::StrCat("Stack: ", stack_name_);
OP_REQUIRES_OK(
ctx, xc.CreateResource(XlaResource::kStack, -1, std::move(name), dtype_,
- value, &resource));
- resource->set_tensor_array_size(size);
+ TensorShape(), value, /*tensor_array_size=*/size,
+ /*tensor_array_gradients=*/{}, &resource));
ctx->SetResourceOutput(0, resource);
}
// TODO(phawkins): We don't check the index is in bounds --- there is no
// error mechanism in XLA.
- OP_REQUIRES_OK(
- ctx,
- resource->SetValue(
- dtype_, b->Tuple({b->DynamicUpdateSlice(ta, update, start_indices),
- b->Add(index, b->ConstantR0<int32>(1))})));
+ OP_REQUIRES_OK(ctx, resource->SetValue(b->Tuple(
+ {b->DynamicUpdateSlice(ta, update, start_indices),
+ b->Add(index, b->ConstantR0<int32>(1))})));
ctx->SetOutput(0, value);
}
xla::ComputationDataHandle index = b->GetTupleElement(state, 1);
index = b->Sub(index, b->ConstantR0<int32>(1));
- OP_REQUIRES_OK(ctx, resource->SetValue(dtype_, b->Tuple({ta, index})));
+ OP_REQUIRES_OK(ctx, resource->SetValue(b->Tuple({ta, index})));
// start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
auto start_indices =
OP_REQUIRES_OK(ctx, ctx->GetAttr("new_axis_mask", &new_axis_mask_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("shrink_axis_mask", &shrink_axis_mask_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
}
void Compile(XlaOpKernelContext* ctx) override {
OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_,
&strides_tensor));
- DataType lhs_type;
TensorShape lhs_shape;
- OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &lhs_type, &lhs_shape));
+ xla::ComputationDataHandle lhs;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &lhs_shape, &lhs));
const TensorShape rhs_shape = ctx->InputShape(4);
" does not match r-value shape ", rhs_shape.DebugString(),
". Automatic broadcasting not yet implemented."));
- xla::ComputationDataHandle lhs;
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &lhs));
-
xla::ComputationDataHandle rhs = ctx->Input(4);
gtl::InlinedVector<int64, 4> dimensions_to_reverse;
lhs, rhs, ctx->builder()->ConstantR1<int64>(slice_begin));
}
- OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, lhs_type, lhs));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, lhs));
}
private:
int32 begin_mask_, end_mask_;
int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_;
DataType index_type_;
+ DataType dtype_;
};
REGISTER_XLA_OP(Name("ResourceStridedSliceAssign")
TF_RET_CHECK(resource->tensor_array_size() >= 0)
<< resource->name() << " size " << resource->tensor_array_size();
- TensorShape ta_shape;
- ta_shape.AddDim(resource->tensor_array_size());
- ta_shape.AppendShape(elem_shape);
if (!resource->initialized()) {
xla::ComputationDataHandle zero =
XlaHelpers::Zero(builder, resource->type());
- TF_RETURN_IF_ERROR(resource->SetValue(
- dtype, builder->Broadcast(zero, ta_shape.dim_sizes())));
+
+ TF_RETURN_IF_ERROR(resource->SetTypeAndShape(dtype, elem_shape));
+ TF_RETURN_IF_ERROR(resource->SetZeroValue(builder));
} else {
// Checks the elem_shape matches the TensorArray shape.
auto shape_or_status = builder->GetShape(resource->value());
TensorShape shape;
TF_RETURN_IF_ERROR(
XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), &shape));
+
+ TensorShape ta_shape;
+ ta_shape.AddDim(resource->tensor_array_size());
+ ta_shape.AppendShape(elem_shape);
if (ta_shape != shape) {
return errors::InvalidArgument(
"Mismatched TensorArray sizes: ", ta_shape.DebugString(), " vs ",
Status GetTensorArrayShape(const XlaResource* resource,
xla::ComputationBuilder* builder,
TensorShape* shape) {
- TF_RETURN_IF_ERROR(resource->GetShape(builder, shape));
- if (shape->dims() < 1) {
- return errors::InvalidArgument("TensorArray rank must be >= 1");
- }
+ *shape = resource->shape();
+ shape->InsertDim(0, resource->tensor_array_size());
return Status::OK();
}
// Initializes the TensorArray value if we know the element shape.
// Otherwise, defer initialization to the first write.
xla::ComputationDataHandle value;
+ TensorShape shape;
if (element_shape_.IsFullyDefined()) {
- TensorShape shape;
CHECK(element_shape_.AsTensorShape(&shape));
TensorShape ta_shape;
ta_shape.AddDim(size);
string name = strings::StrCat("TensorArray: ", tensor_array_name_);
OP_REQUIRES_OK(
ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name),
- dtype_, value, &var));
- var->set_tensor_array_size(size);
+ dtype_, shape, value, /*tensor_array_size=*/size,
+ /*tensor_array_gradients=*/{}, &var));
ctx->SetResourceOutput(0, var);
Tensor flow(DT_FLOAT, TensorShape({}));
xla::ComputationDataHandle written =
DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices);
- OP_REQUIRES_OK(ctx, resource->SetValue(dtype_, written));
+ OP_REQUIRES_OK(ctx, resource->SetValue(written));
ctx->SetOutput(0, flow);
}
}
}
- OP_REQUIRES_OK(ctx, resource->SetValue(dtype_, ta));
+ OP_REQUIRES_OK(ctx, resource->SetValue(ta));
ctx->SetOutput(0, flow);
}
value_shape.DebugString(), " vs. ",
ta_shape.DebugString()));
- OP_REQUIRES_OK(
- ctx, resource->SetValue(
- dtype_, b->Add(ta, b->Reshape(value, ta_shape.dim_sizes()))));
+ OP_REQUIRES_OK(ctx, resource->SetValue(b->Add(
+ ta, b->Reshape(value, ta_shape.dim_sizes()))));
ctx->SetOutput(0, flow);
}
void Compile(XlaOpKernelContext* ctx) override {
xla::ComputationDataHandle handle;
xla::ComputationBuilder* b = ctx->builder();
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle));
+ DataType type = ctx->input_type(1);
+ TensorShape var_shape;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &handle));
+
+ TensorShape alpha_shape = ctx->InputShape(1);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape),
+ errors::InvalidArgument("alpha is not a scalar: ",
+ alpha_shape.DebugString()));
+
+ TensorShape delta_shape = ctx->InputShape(2);
+ OP_REQUIRES(
+ ctx, var_shape.IsSameSize(delta_shape),
+ errors::InvalidArgument("var and delta do not have the same shape: ",
+ var_shape.DebugString(), " vs ",
+ delta_shape.DebugString()));
+
handle = b->Sub(handle, b->Mul(ctx->Input(1), ctx->Input(2)));
- OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, ctx->input_type(1), handle));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle));
}
};
REGISTER_XLA_OP(
DataType type = ctx->input_type(2);
- DataType var_type, accum_type;
TensorShape var_shape, accum_shape;
- OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape));
- OP_REQUIRES_OK(ctx,
- ctx->GetVariableTypeAndShape(1, &accum_type, &accum_shape));
-
- OP_REQUIRES(
- ctx, type == var_type && type == accum_type,
- errors::InvalidArgument(
- "Types of variable arguments to ResourceApplyMomentum must match: ",
- DataTypeString(type), " vs. ", DataTypeString(var_type), " and ",
- DataTypeString(accum_type)));
+ xla::ComputationDataHandle var, accum;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum));
OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
errors::InvalidArgument(
errors::InvalidArgument("momentum is not a scalar: ",
momentum_shape.DebugString()));
- xla::ComputationDataHandle var, accum;
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var));
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &accum));
-
xla::ComputationDataHandle lr = ctx->Input(2);
xla::ComputationDataHandle grad = ctx->Input(3);
xla::ComputationDataHandle momentum = ctx->Input(4);
DataType type = ctx->input_type(2);
- DataType var_type, accum_type;
TensorShape var_shape, accum_shape;
- OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape));
- OP_REQUIRES_OK(ctx,
- ctx->GetVariableTypeAndShape(1, &accum_type, &accum_shape));
-
- OP_REQUIRES(
- ctx, type == var_type && type == accum_type,
- errors::InvalidArgument(
- "Types of variable arguments to ResourceApplyAdagrad must match: ",
- DataTypeString(type), " vs. ", DataTypeString(var_type), " and ",
- DataTypeString(accum_type)));
+ xla::ComputationDataHandle var, accum;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum));
OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
errors::InvalidArgument(
"var and grad do not have the same shape",
var_shape.DebugString(), " ", grad_shape.DebugString()));
- xla::ComputationDataHandle var, accum;
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var));
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &accum));
xla::ComputationDataHandle lr = ctx->Input(2);
xla::ComputationDataHandle grad = ctx->Input(3);
}
void Compile(XlaOpKernelContext* ctx) override {
- DataType var_type, m_type, v_type;
TensorShape var_shape, m_shape, v_shape;
- OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape));
- OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(1, &m_type, &m_shape));
- OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(2, &v_type, &v_shape));
-
- OP_REQUIRES(
- ctx, dtype_ == var_type && dtype_ == m_type && dtype_ == v_type,
- errors::InvalidArgument(
- "Types of variable arguments to ResourceApplyRMSProp must match: ",
- DataTypeString(dtype_), " vs. ", DataTypeString(var_type), " vs. ",
- DataTypeString(m_type), " vs. ", DataTypeString(v_type)));
+ xla::ComputationDataHandle var, m, v;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype_, &m_shape, &m));
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &v_shape, &v));
TensorShape beta1_power_shape = ctx->InputShape(3);
TensorShape beta2_power_shape = ctx->InputShape(4);
"var and grad do not have the same shape",
var_shape.DebugString(), " ", grad_shape.DebugString()));
- xla::ComputationDataHandle var, m, v;
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var));
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &m));
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, &v));
xla::ComputationDataHandle beta1_power = ctx->Input(3);
xla::ComputationDataHandle beta2_power = ctx->Input(4);
xla::ComputationDataHandle lr = ctx->Input(5);
DataType type = ctx->input_type(3);
- DataType var_type, ms_type, mom_type;
TensorShape var_shape, ms_shape, mom_shape;
- OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape));
- OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(1, &ms_type, &ms_shape));
- OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(2, &mom_type, &mom_shape));
-
- OP_REQUIRES(
- ctx, type == var_type && type == ms_type && type == mom_type,
- errors::InvalidArgument(
- "Types of variable arguments to ResourceApplyRMSProp must match: ",
- DataTypeString(type), " vs. ", DataTypeString(var_type), " vs. ",
- DataTypeString(ms_type), " vs. ", DataTypeString(mom_type)));
+ xla::ComputationDataHandle var, ms, mom;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &ms_shape, &ms));
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, type, &mom_shape, &mom));
TensorShape lr_shape = ctx->InputShape(3);
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
"var and grad do not have the same shape",
var_shape.DebugString(), " ", grad_shape.DebugString()));
- xla::ComputationDataHandle var, ms, mom;
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var));
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &ms));
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, &mom));
xla::ComputationDataHandle lr = ctx->Input(3);
xla::ComputationDataHandle rho = ctx->Input(4);
xla::ComputationDataHandle momentum = ctx->Input(5);
bool has_l2_shrinkage) {
xla::ComputationBuilder* b = ctx->builder();
- DataType var_type, accum_type, linear_type;
TensorShape var_shape, accum_shape, linear_shape;
- OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape));
- OP_REQUIRES_OK(ctx,
- ctx->GetVariableTypeAndShape(1, &accum_type, &accum_shape));
- OP_REQUIRES_OK(ctx,
- ctx->GetVariableTypeAndShape(2, &linear_type, &linear_shape));
-
- OP_REQUIRES(
- ctx, dtype == var_type && dtype == accum_type && dtype == linear_type,
- errors::InvalidArgument(
- "Types of variable arguments to ResourceApplyFtrlV2 must match: ",
- DataTypeString(dtype), " vs. ", DataTypeString(var_type), " and ",
- DataTypeString(accum_type), " and ", DataTypeString(linear_type)));
+ xla::ComputationDataHandle var, accum, linear;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype, &var_shape, &var));
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype, &accum_shape, &accum));
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype, &linear_shape, &linear));
OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
errors::InvalidArgument(
errors::InvalidArgument("lr_power is not a scalar: ",
lr_power_shape.DebugString()));
- xla::ComputationDataHandle var, accum, linear;
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var));
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &accum));
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, &linear));
xla::ComputationDataHandle grad = ctx->Input(3);
xla::ComputationDataHandle lr = ctx->Input(4);
xla::ComputationDataHandle l1 = ctx->Input(5);
public:
explicit VarIsInitializedOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationDataHandle handle;
- bool initialized = ctx->ReadVariableInput(0, &handle).ok();
- ctx->SetOutput(0, ctx->builder()->ConstantR0<bool>(initialized));
+ XlaResource* variable;
+ OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &variable));
+ ctx->SetOutput(0,
+ ctx->builder()->ConstantR0<bool>(variable->initialized()));
}
};
REGISTER_XLA_OP(Name("VarIsInitializedOp"), VarIsInitializedOp);
class ReadVariableOp : public XlaOpKernel {
public:
- explicit ReadVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+ explicit ReadVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
+ }
+
void Compile(XlaOpKernelContext* ctx) override {
xla::ComputationDataHandle handle;
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle));
+ OP_REQUIRES_OK(
+ ctx, ctx->ReadVariableInput(0, dtype_, /*shape=*/nullptr, &handle));
ctx->SetOutput(0, handle);
}
+
+ private:
+ DataType dtype_;
};
REGISTER_XLA_OP(Name("ReadVariableOp"), ReadVariableOp);
public:
explicit AssignAddVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
+ DataType type = ctx->input_type(1);
xla::ComputationDataHandle handle;
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle));
+ OP_REQUIRES_OK(ctx,
+ ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle));
handle = ctx->builder()->Add(handle, ctx->Input(1));
- OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, ctx->input_type(1), handle));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle));
}
};
REGISTER_XLA_OP(
public:
explicit AssignSubVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
+ DataType type = ctx->input_type(1);
xla::ComputationDataHandle handle;
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle));
+ OP_REQUIRES_OK(ctx,
+ ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle));
handle = ctx->builder()->Sub(handle, ctx->Input(1));
- OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, ctx->input_type(1), handle));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle));
}
};
REGISTER_XLA_OP(
void Compile(XlaOpKernelContext* ctx) override {
xla::ComputationBuilder* builder = ctx->builder();
- // Get the shape of the resource tensor.
- TensorShape resource_shape;
- DataType resource_dtype;
- OP_REQUIRES_OK(
- ctx, ctx->GetVariableTypeAndShape(0, &resource_dtype, &resource_shape));
-
- DataType expected_output_dtype = ctx->expected_output_dtype(0);
- OP_REQUIRES(ctx, resource_dtype == expected_output_dtype,
- errors::InvalidArgument(
- "Variable dtype is ", DataTypeString(resource_dtype),
- " but expected output dtype is ",
- DataTypeString(expected_output_dtype), "."));
+ DataType type = ctx->expected_output_dtype(0);
+ TensorShape resource_shape;
xla::ComputationDataHandle resource_handle;
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &resource_handle));
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &resource_shape,
+ &resource_handle));
auto indices = ctx->Input(1);
auto indices_shape = ctx->InputShape(1);
DataType index_type = ctx->input_type(1);
xla::ComputationDataHandle gather = XlaComputeGatherDynamicSlice(
- ctx, resource_handle, resource_shape, indices, indices_shape, 0,
- resource_dtype, index_type, builder);
+ ctx, resource_handle, resource_shape, indices, indices_shape, 0, type,
+ index_type, builder);
ctx->SetOutput(0, gather);
}
};
}
arg.type = resource->type();
- if (arg.initialized) {
- TF_RETURN_IF_ERROR(resource->PackedShape(ctx->builder(), &arg.shape));
- } else {
+ arg.shape = resource->shape();
+ if (!arg.initialized) {
*has_uninitialized_vars = true;
}
arg.tensor_array_size = resource->tensor_array_size();
arg.name = resource->name();
VLOG(2) << " resource " << resource->name()
<< " type: " << DataTypeString(arg.type)
- << " shape: " << xla::ShapeUtil::HumanString(arg.shape)
+ << " shape: " << arg.shape.DebugString()
<< " initialized: " << arg.initialized;
} else {
arg.kind = XlaCompiler::Argument::kParameter;
arg.type = ctx->input_type(i);
- TF_RETURN_IF_ERROR(
- TensorShapeToXLAShape(arg.type, ctx->InputShape(i), &arg.shape));
+ arg.shape = ctx->InputShape(i);
}
}
return Status::OK();
XlaCompiler::Argument& arg = arguments[update.input_index];
if (!arg.initialized) {
VLOG(2) << "Update shape for argument " << update.input_index << " "
- << xla::ShapeUtil::HumanString(update.shape);
+ << update.shape.DebugString();
arg.initialized = true;
- xla::Shape shape = update.shape;
- if (!update.tensor_array_gradients_accessed.empty()) {
- shape = xla::ShapeUtil::GetTupleElementShape(shape, 0);
- }
- std::unique_ptr<xla::Literal> zero =
- xla::Literal::CreateFromShape(shape);
- OP_REQUIRES_OK(ctx, resource->SetValue(
- update.type, builder->ConstantLiteral(*zero)));
+ arg.shape = update.shape;
+ OP_REQUIRES_OK(ctx,
+ resource->SetTypeAndShape(update.type, update.shape));
+
+ OP_REQUIRES_OK(ctx, resource->SetZeroValue(builder));
}
// Add any TensorArray gradients touched by the body to the enclosing
for (const auto& gradient : resource->tensor_array_gradients()) {
arg.tensor_array_gradients.insert(gradient.first);
}
-
- // Recompute the argument shape.
- OP_REQUIRES_OK(ctx, resource->PackedShape(ctx->builder(), &arg.shape));
}
// Recompile the body with the "correct" resource shapes.
VLOG(1) << "Recompiling body with corrected resource shapes";
OP_REQUIRES_OK(ctx,
resource->SetFromPack(
arguments[update.input_index].tensor_array_gradients,
- builder->GetTupleElement(while_result, pos),
- /*reset_initial_values=*/false, builder));
+ builder->GetTupleElement(while_result, pos), builder));
}
VLOG(2) << "Loop-carried variable: pos: " << update.input_index
<< " name: " << resource->name() << " modified: " << update.modified
<< " type: " << DataTypeString(update.type)
- << " shape: " << xla::ShapeUtil::HumanString(update.shape);
+ << " shape: " << update.shape.DebugString();
// Copies the identity of the resource variable from input to output
// unchanged, even if the variable was not modified.
ctx->op_kernel_context()->set_output(
XlaCompiler::Argument arg;
arg.kind = XlaCompiler::Argument::kParameter;
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &arg.type));
- TensorShape shape;
- TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &shape));
- TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, &arg.shape));
+ TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &arg.shape));
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kDebugNameAttr, &arg.name));
xla_args->push_back(arg);
}
bool XlaCompiler::Argument::operator==(
const XlaCompiler::Argument& other) const {
- if (std::tie(kind, resource_kind, type, name, tensor_array_size,
+ if (std::tie(kind, resource_kind, type, name, initialized, tensor_array_size,
tensor_array_gradients) !=
std::tie(other.kind, other.resource_kind, other.type, other.name,
- other.tensor_array_size, other.tensor_array_gradients)) {
+ other.initialized, other.tensor_array_size,
+ other.tensor_array_gradients)) {
return false;
}
- if (!xla::ShapeUtil::Equal(shape, other.shape)) {
+ if (shape != other.shape) {
return false;
}
if (constant_value.shape() != other.constant_value.shape()) {
return Status::OK();
}
+// Computes the XLA shape for argument 'arg'.
+/*static*/ Status XlaCompiler::XLAShapeForArgument(
+ const XlaCompiler::Argument& arg, xla::Shape* xla_shape) {
+ switch (arg.kind) {
+ case XlaCompiler::Argument::kConstant:
+ return TensorShapeToXLAShape(arg.type, arg.constant_value.shape(),
+ xla_shape);
+ case XlaCompiler::Argument::kParameter:
+ return TensorShapeToXLAShape(arg.type, arg.shape, xla_shape);
+ case XlaCompiler::Argument::kResource: {
+ TF_RET_CHECK(arg.initialized);
+
+ switch (arg.resource_kind) {
+ case XlaResource::kVariable:
+ return TensorShapeToXLAShape(arg.type, arg.shape, xla_shape);
+ case XlaResource::kTensorArray: {
+ if (arg.tensor_array_size < 0) {
+ return errors::InvalidArgument(
+ "Negative tensor_array_size in XLAShapeForArgument");
+ }
+ TensorShape shape;
+ shape.AddDim(arg.tensor_array_size);
+ shape.AppendShape(arg.shape);
+ TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, xla_shape));
+
+ if (!arg.tensor_array_gradients.empty()) {
+ std::vector<xla::Shape> tuple_shape(
+ arg.tensor_array_gradients.size() + 1, *xla_shape);
+ *xla_shape = xla::ShapeUtil::MakeTupleShape(tuple_shape);
+ }
+ return Status::OK();
+ }
+ case XlaResource::kStack: {
+ if (arg.tensor_array_size < 0) {
+ return errors::InvalidArgument(
+ "Negative tensor_array_size in XLAShapeForArgument");
+ }
+ TensorShape shape;
+ shape.AddDim(arg.tensor_array_size);
+ shape.AppendShape(arg.shape);
+ xla::Shape buffer_shape;
+ TF_RETURN_IF_ERROR(
+ TensorShapeToXLAShape(arg.type, shape, &buffer_shape));
+ *xla_shape = xla::ShapeUtil::MakeTupleShape(
+ {buffer_shape, xla::ShapeUtil::MakeShape(xla::S32, {})});
+ return Status::OK();
+ }
+
+ case XlaResource::kInvalid:
+ return errors::Internal(
+ "Invalid resource type in XLAShapeForArgument()");
+ }
+ }
+ case XlaCompiler::Argument::kInvalid:
+ return errors::Internal("Invalid argument type in XLAShapeForArgument()");
+ }
+}
+
namespace {
Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
// Argument numbers of arguments and resources that are to be passed to the
// XLA computation as runtime parameters.
- std::vector<int> parameters, resources;
- parameters.reserve(args.size());
+ input_mapping->clear();
+ input_mapping->reserve(args.size());
+ std::vector<int> resources;
resources.reserve(args.size());
// Fills in constant arguments, and computes non-constant argument order.
// TODO(phawkins): this code assumes that resource arguments do not
// alias.
XlaResource* resource;
- TF_RETURN_IF_ERROR(
- context->CreateResource(arg.resource_kind, i, arg.name, arg.type,
- xla::ComputationDataHandle(), &resource));
- resource->set_tensor_array_size(arg.tensor_array_size);
+ TF_RETURN_IF_ERROR(context->CreateResource(
+ arg.resource_kind, i, arg.name, arg.type, arg.shape,
+ xla::ComputationDataHandle(),
+ /*tensor_array_size=*/arg.tensor_array_size,
+ /*tensor_array_gradients=*/arg.tensor_array_gradients, &resource));
arg_expression.set_resource(resource);
if (arg.initialized) {
resources.push_back(i);
}
break;
- case XlaCompiler::Argument::kParameter:
- parameters.push_back(i);
+ case XlaCompiler::Argument::kParameter: {
+ input_mapping->push_back(i);
break;
+ }
case XlaCompiler::Argument::kConstant:
arg_expression.set_constant_value(arg.constant_value);
break;
// Append parameters containing variable values after the other runtime
// parameters.
- parameters.insert(parameters.end(), resources.begin(), resources.end());
- if (parameters.empty()) {
+ input_mapping->insert(input_mapping->end(), resources.begin(),
+ resources.end());
+ if (input_mapping->empty()) {
return Status::OK();
}
- std::vector<xla::Shape> arg_shapes;
- arg_shapes.reserve(parameters.size());
- input_mapping->resize(parameters.size());
- for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) {
- const XlaCompiler::Argument& arg = args[parameters[i]];
+ std::vector<xla::Shape> arg_shapes(input_mapping->size());
+ for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
// Computes the shapes of non-constant arguments.
- arg_shapes.push_back(arg.shape);
- (*input_mapping)[i] = parameters[i];
+ TF_RETURN_IF_ERROR(XlaCompiler::XLAShapeForArgument(
+ args[(*input_mapping)[i]], &arg_shapes[i]));
}
if (use_tuple_arg) {
}
// Build parameter handles for non-constant arguments.
- std::vector<xla::ComputationDataHandle> arg_handles(parameters.size());
+ std::vector<xla::ComputationDataHandle> arg_handles(input_mapping->size());
if (use_tuple_arg) {
xla::ComputationDataHandle tuple;
if (is_entry_computation) {
xla::OpSharding tuple_sharding;
tuple_sharding.set_type(xla::OpSharding::Type::OpSharding_Type_TUPLE);
- for (int64 parameter : parameters) {
+ for (int64 parameter : *input_mapping) {
const int core = (*arg_cores)[parameter];
const int root_device = 0;
*tuple_sharding.add_tuple_shardings() =
} else {
tuple = builder->Parameter(0, (*input_shapes)[0], "arg_tuple");
}
- for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) {
- const int core = (*arg_cores)[parameters[i]];
+ for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
+ const int core = (*arg_cores)[input_mapping->at(i)];
xla::ScopedShardingAssignment assign_sharding(
builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>()
: xla::sharding_builder::AssignDevice(core));
arg_handles[i] = builder->GetTupleElement(tuple, i);
}
} else {
- for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) {
- const int core = (*arg_cores)[parameters[i]];
+ for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
+ const int core = (*arg_cores)[input_mapping->at(i)];
xla::ScopedShardingAssignment assign_sharding(
builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>()
: xla::sharding_builder::AssignDevice(core));
// Fill in the handles in non-constant arguments.
VLOG(2) << "XLA computation inputs:";
- for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) {
- const XlaCompiler::Argument& arg = args[parameters[i]];
+ for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
+ const XlaCompiler::Argument& arg = args[input_mapping->at(i)];
VLOG(2) << " XLA arg " << i
<< " shape: " << xla::ShapeUtil::HumanString(arg_shapes[i])
- << " name: " << arg.name << " TF arg " << parameters[i];
- XlaExpression& arg_expression = (*arg_expressions)[parameters[i]];
+ << " name: " << arg.name << " TF arg " << input_mapping->at(i);
+ XlaExpression& arg_expression = (*arg_expressions)[input_mapping->at(i)];
switch (arg.kind) {
case XlaCompiler::Argument::kResource: {
TF_RET_CHECK(arg.initialized);
XlaResource* resource = arg_expression.resource();
- TF_RETURN_IF_ERROR(
- resource->SetFromPack(arg.tensor_array_gradients, arg_handles[i],
- /*reset_initial_values=*/true, builder));
+ TF_RETURN_IF_ERROR(resource->SetFromPack(arg.tensor_array_gradients,
+ arg_handles[i], builder));
VLOG(2) << " resource: num_gradients: "
<< arg.tensor_array_gradients.size();
break;
XlaCompiler::ResourceUpdate& update = resource_updates->back();
update.input_index = resource->arg_num();
update.type = resource->type();
+ update.shape = resource->shape();
update.modified = modified;
for (const auto& grad : resource->tensor_array_gradients()) {
update.tensor_array_gradients_accessed.insert(grad.first);
++computation_output;
}
}
-
- for (std::vector<ResourceUpdate>::size_type i = 0;
- i < result->resource_updates.size(); ++i) {
- result->resource_updates[i].shape = xla::ShapeUtil::GetTupleElementShape(
- result->xla_output_shape, computation_output);
- ++computation_output;
- }
return Status::OK();
}
// is the type of the variable's value, not DT_RESOURCE.
DataType type;
- // The shape of the argument. If the argument is a resource, this is the
- // shape of the resource's value.
- xla::Shape shape;
+ // The shape of the argument. For:
+ // * a parameter: the shape of the parameter.
+ // * a constant: ignored; the shape given by constant_value is used
+ // instead.
+ // * an uninitialized resource: ignored. We don't yet know the shape of an
+ // uninitialized resource (otherwise we would have initialized it!)
+ // * an initialized variable: the shape of the variable's value.
+ // * an initialized TensorArray or Stack resource: the shape of an entry in
+ // the TensorArray/Stack. Note this is the size of a single entry, not the
+ // XLA data structure that represents the complete stack/array.
+ TensorShape shape;
// The value of the argument, if it is a compile-time constant. Must be a
// host-memory tensor.
int input_index;
// Type and shape of the tensor to be written back.
+ // The `shape` field has the same meaning as the Argument::shape field.
DataType type;
- xla::Shape shape;
+ TensorShape shape;
// Was the value of the variable modified by the computation?
// (Always true, unless `return_updated_values_for_all_resources` is true.)
const std::vector<Argument>& args,
CompilationResult* result);
- Status PrepareArguments(xla::ComputationBuilder* builder, NameAttrList func,
- const std::vector<DataType>& types,
- const std::vector<TensorShape>& shapes,
- const std::vector<const XlaExpression*>& expressions,
- std::vector<Argument>* args);
+ // Returns the shape of the XLA parameter for an argument 'arg'.
+ // See the class comment for more details about the argument passing
+ // convention.
+ static Status XLAShapeForArgument(const Argument& arg, xla::Shape* xla_shape);
// Retrieves the channel handle associated with `key`. Allocates
// a new channel handle if none exists.
std::vector<XlaCompiler::Argument> args(2);
args[0].kind = XlaCompiler::Argument::kParameter;
args[0].type = DT_INT32;
- args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
+ args[0].shape = TensorShape({2});
args[1].kind = XlaCompiler::Argument::kParameter;
args[1].type = DT_INT32;
- args[1].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
+ args[1].shape = TensorShape({2});
// Compiles the graph.
XlaCompiler compiler(DefaultOptions());
std::vector<XlaCompiler::Argument> args(2);
args[0].kind = XlaCompiler::Argument::kParameter;
args[0].type = DT_INT32;
- args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
+ args[0].shape = TensorShape({2});
args[1].kind = XlaCompiler::Argument::kParameter;
args[1].type = DT_INT32;
- args[1].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
+ args[1].shape = TensorShape({2});
// Compiles the graph.
XlaCompiler compiler(DefaultOptions());
std::vector<XlaCompiler::Argument> args(1);
args[0].kind = XlaCompiler::Argument::kParameter;
args[0].type = DT_INT32;
- args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
+ args[0].shape = TensorShape({2});
XlaCompiler::Options options = DefaultOptions();
XlaCompiler compiler(options);
std::vector<XlaCompiler::Argument> args(1);
args[0].kind = XlaCompiler::Argument::kParameter;
args[0].type = DT_INT32;
- args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
+ args[0].shape = TensorShape({2});
DummyResourceForTest* resource = new DummyResourceForTest();
std::vector<XlaCompiler::Argument> args(1);
args[0].kind = XlaCompiler::Argument::kParameter;
args[0].type = DT_INT32;
- args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
+ args[0].shape = TensorShape({2});
// Compiles the graph.
auto options = DefaultOptions();
args[0].resource_kind = XlaResource::kTensorArray;
args[0].initialized = true;
args[0].type = DT_INT32;
- args[0].shape = xla::ShapeUtil::MakeTupleShape(
- {xla::ShapeUtil::MakeShape(xla::S32, {2}),
- xla::ShapeUtil::MakeShape(xla::S32, {2})});
+ args[0].shape = TensorShape({});
args[0].tensor_array_size = 2;
args[0].tensor_array_gradients = {"grad2"};
args[0].resource_kind = XlaResource::kTensorArray;
args[0].initialized = true;
args[0].type = DT_INT32;
- args[0].shape = xla::ShapeUtil::MakeTupleShape(
- {xla::ShapeUtil::MakeShape(xla::S32, {2}),
- xla::ShapeUtil::MakeShape(xla::S32, {2})});
+ args[0].shape = TensorShape({});
args[0].tensor_array_size = 2;
args[0].tensor_array_gradients = {"grad1"};
args[0].resource_kind = XlaResource::kTensorArray;
args[0].initialized = true;
args[0].type = DT_INT32;
- args[0].shape = xla::ShapeUtil::MakeTupleShape(
- {xla::ShapeUtil::MakeShape(xla::S32, {2}),
- xla::ShapeUtil::MakeShape(xla::S32, {2})});
+ args[0].shape = TensorShape({});
args[0].tensor_array_size = 2;
args[0].tensor_array_gradients = {"grad1"};
xla::ComputationBuilder* XlaContext::builder() { return builder_; }
-Status XlaContext::CreateResource(XlaResource::Kind kind, int arg_num,
- string name, DataType type,
- const xla::ComputationDataHandle& handle,
- XlaResource** resource) {
+Status XlaContext::CreateResource(
+ XlaResource::Kind kind, int arg_num, string name, DataType type,
+ TensorShape shape, const xla::ComputationDataHandle& handle,
+ int64 tensor_array_size, const std::set<string>& tensor_array_gradients,
+ XlaResource** resource) {
resources_.emplace_back(
- new XlaResource(kind, arg_num, std::move(name), type, handle));
+ new XlaResource(kind, arg_num, std::move(name), type, std::move(shape),
+ handle, tensor_array_size, tensor_array_gradients));
*resource = resources_.back().get();
return Status::OK();
}
Status AddConstRetval(int retval_index, DataType dtype,
const xla::Literal& literal);
- // Creates a resource with resource `kind` and initial type `type` and
- // value `handle`. `name` is a descriptive name for use in error messages.
+ // Creates a resource with resource `kind` and initial value `handle`. `name`
+ // is a descriptive name for use in error messages. See the `XlaResource`
+ // constructor for a description of the remaining arguments.
// Fails if the resource already exists.
Status CreateResource(XlaResource::Kind kind, int arg_num, string name,
- DataType type, const xla::ComputationDataHandle& handle,
+ DataType type, TensorShape shape,
+ const xla::ComputationDataHandle& handle,
+ int64 tensor_array_size,
+ const std::set<string>& tensor_array_gradients,
XlaResource** resource);
const std::vector<std::unique_ptr<XlaResource>>& resources() {
}
Status XlaOpKernelContext::ReadVariableInput(
- int index, xla::ComputationDataHandle* value) {
+ int index, DataType type, TensorShape* shape,
+ xla::ComputationDataHandle* value) {
const Tensor& tensor = context_->input(index);
const XlaExpression* expression = CastExpressionFromTensor(tensor);
XlaResource* variable = expression->resource();
return errors::InvalidArgument("Read of uninitialized variable ",
variable->name());
}
+ if (variable->type() != type) {
+ return errors::InvalidArgument(
+ "Type mismatch for read of variable ", variable->name(), ". Expected ",
+ DataTypeString(type), "; got ", DataTypeString(variable->type()));
+ }
*value = variable->value();
+ if (shape) {
+ *shape = variable->shape();
+ }
return Status::OK();
}
variable->name());
}
*type = variable->type();
- auto shape_or_status = builder()->GetShape(variable->value());
- if (!shape_or_status.ok()) {
- return shape_or_status.status();
- }
- TF_RETURN_IF_ERROR(
- XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), shape));
+ *shape = variable->shape();
return Status::OK();
}
XlaResource* variable = expression->resource();
TF_RET_CHECK(variable != nullptr);
TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
- return variable->SetValue(type, handle);
+
+ auto shape_or_status = builder()->GetShape(handle);
+ if (!shape_or_status.ok()) {
+ return shape_or_status.status();
+ }
+ TensorShape shape;
+ TF_RETURN_IF_ERROR(
+ XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), &shape));
+
+ TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape));
+ return variable->SetValue(handle);
}
XlaCompiler* XlaOpKernelContext::compiler() const {
TensorShape* shape) const;
// Reads the current value of the resouce variable referred to by input
- // 'index'.
- Status ReadVariableInput(int index, xla::ComputationDataHandle* value);
+ // 'index'. If `shape` is not nullptr, sets `*shape` to the shape of the
+ // variable. Returns an error if the variable has not been initialized, or if
+ // its type does not match `type`.
+ Status ReadVariableInput(int index, DataType type, TensorShape* shape,
+ xla::ComputationDataHandle* value);
// Assigns the value `handle` to the variable referenced by input
- // `input_index`. Marks the operator as having side effects.
+ // `input_index`. The variable must be of `type`. Returns an error if the
+ // variable has been initialized with a different type or with a
+ // different shape.
Status AssignVariable(int input_index, DataType type,
const xla::ComputationDataHandle& handle);
namespace tensorflow {
-XlaResource::XlaResource(Kind kind, int arg_num, string name,
- DataType initial_type,
- const xla::ComputationDataHandle& initial_value)
+XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type,
+ TensorShape shape,
+ const xla::ComputationDataHandle& initial_value,
+ int64 tensor_array_size,
+ const std::set<string>& tensor_array_gradients)
: kind_(kind),
arg_num_(arg_num),
name_(std::move(name)),
- type_(initial_type),
+ type_(type),
+ shape_(std::move(shape)),
value_(initial_value),
- initial_value_(initial_value) {
+ initial_value_(initial_value),
+ tensor_array_size_(tensor_array_size) {
CHECK(kind_ != kInvalid);
+
+ for (const string& gradient : tensor_array_gradients) {
+ tensor_array_gradients_[gradient].reset(
+ new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1,
+ /*name=*/strings::StrCat("TensorArrayGrad: ", name_),
+ type_, shape_, xla::ComputationDataHandle(),
+ tensor_array_size_, /*tensor_array_gradients=*/{}));
+ }
}
-Status XlaResource::SetValue(DataType type,
- const xla::ComputationDataHandle& value) {
- if (type_ == DT_INVALID && type == DT_INVALID) {
- return errors::InvalidArgument("Attempted to initialized resource ", name_,
- " to an invalid type");
+Status XlaResource::SetTypeAndShape(DataType type, const TensorShape& shape) {
+ if (type == DT_INVALID) {
+ return errors::InvalidArgument("Attempted to set type of resource '", name_,
+ "'' to an invalid type");
}
- if (type_ != DT_INVALID && type_ != type) {
+ if (initialized() && type_ != type) {
return errors::InvalidArgument("Type of resource ", name_,
" cannot be changed after initialization: "
"old type was ",
DataTypeString(type_), ", new type is ",
DataTypeString(type));
}
+ if (initialized() && shape_ != shape) {
+ return errors::InvalidArgument("Shape of resource ", name_,
+ " cannot be changed after initialization: "
+ "old shape was ",
+ shape_.DebugString(), ", new shape is ",
+ shape.DebugString());
+ }
type_ = type;
- value_ = value;
+ shape_ = shape;
return Status::OK();
}
-Status XlaResource::GetXlaShape(xla::ComputationBuilder* builder,
- xla::Shape* shape) const {
- auto shape_or_status = builder->GetShape(value_);
- if (!shape_or_status.ok()) {
- return shape_or_status.status();
+Status XlaResource::SetValue(const xla::ComputationDataHandle& value) {
+ if (type_ == DT_INVALID) {
+ return errors::InvalidArgument(
+ "Resource '", name_,
+ "' must be initialized with a valid type before use.");
}
- *shape = *shape_or_status.ValueOrDie();
+ value_ = value;
return Status::OK();
}
-Status XlaResource::GetShape(xla::ComputationBuilder* builder,
- TensorShape* shape) const {
- xla::Shape xla_shape;
- TF_RETURN_IF_ERROR(GetXlaShape(builder, &xla_shape));
- TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, shape));
+Status XlaResource::SetZeroValue(xla::ComputationBuilder* builder) {
+ if (type_ == DT_INVALID) {
+ return errors::InvalidArgument(
+ "Resource '", name_,
+ "' must be initialized with a valid type before use.");
+ }
+ switch (kind_) {
+ case kVariable: {
+ value_ = builder->Broadcast(XlaHelpers::Zero(builder, type_),
+ shape_.dim_sizes());
+ break;
+ }
+ case kTensorArray: {
+ TensorShape ta_shape;
+ ta_shape.AddDim(tensor_array_size_);
+ ta_shape.AppendShape(shape_);
+ value_ = builder->Broadcast(XlaHelpers::Zero(builder, type_),
+ ta_shape.dim_sizes());
+ break;
+ }
+ case kStack: {
+ TensorShape ta_shape;
+ ta_shape.AddDim(tensor_array_size_);
+ ta_shape.AppendShape(shape_);
+ value_ =
+ builder->Tuple({builder->Broadcast(XlaHelpers::Zero(builder, type_),
+ ta_shape.dim_sizes()),
+ builder->ConstantR0<int32>(0)});
+ break;
+ }
+
+ case kInvalid:
+ default:
+ LOG(FATAL) << "Invalid resource type";
+ }
return Status::OK();
}
std::unique_ptr<XlaResource>& gradient = tensor_array_gradients_[source];
if (!gradient) {
TensorShape ta_shape;
- TF_RETURN_IF_ERROR(GetShape(builder, &ta_shape));
+ ta_shape.AddDim(tensor_array_size_);
+ ta_shape.AppendShape(shape_);
xla::ComputationDataHandle gradient_value = builder->Broadcast(
XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes());
gradient.reset(
new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1,
/*name=*/strings::StrCat("TensorArrayGrad: ", name_),
- type_, gradient_value));
- gradient->tensor_array_size_ = tensor_array_size_;
+ type_, shape_, gradient_value, tensor_array_size_,
+ /*tensor_array_gradients=*/{}));
}
*gradient_out = gradient.get();
return Status::OK();
}
-Status XlaResource::PackedShape(xla::ComputationBuilder* builder,
- xla::Shape* packed_shape) const {
- if (tensor_array_gradients_.empty()) {
- return GetXlaShape(builder, packed_shape);
- }
- TF_RET_CHECK(kind_ == kTensorArray);
- std::vector<xla::Shape> elem_shapes(1 + tensor_array_gradients_.size());
- int pos = 0;
- TF_RETURN_IF_ERROR(GetXlaShape(builder, &elem_shapes[pos++]));
- for (const auto& gradient : tensor_array_gradients_) {
- TF_RETURN_IF_ERROR(
- gradient.second->GetXlaShape(builder, &elem_shapes[pos++]));
- }
- *packed_shape = xla::ShapeUtil::MakeTupleShape(elem_shapes);
- return Status::OK();
-}
-
Status XlaResource::Pack(xla::ComputationDataHandle* pack,
xla::ComputationBuilder* builder) const {
if (tensor_array_gradients_.empty()) {
Status XlaResource::SetFromPack(const std::set<string>& gradient_sources,
const xla::ComputationDataHandle& pack,
- bool reset_initial_values,
xla::ComputationBuilder* builder) {
if (gradient_sources.empty()) {
+ if (!initialized()) {
+ initial_value_ = pack;
+ }
value_ = pack;
} else {
TF_RET_CHECK(kind_ == kTensorArray);
int pos = 0;
- value_ = builder->GetTupleElement(pack, pos++);
+ auto v = builder->GetTupleElement(pack, pos++);
+ if (!initialized()) {
+ initial_value_ = v;
+ }
+ value_ = v;
+
for (const auto& source : gradient_sources) {
XlaResource* gradient;
TF_RETURN_IF_ERROR(
GetOrCreateTensorArrayGradient(source, builder, &gradient));
- gradient->value_ = builder->GetTupleElement(pack, pos++);
- if (reset_initial_values) {
- gradient->initial_value_ = gradient->value_;
+ auto v = builder->GetTupleElement(pack, pos++);
+ if (!gradient->initialized()) {
+ gradient->initial_value_ = v;
}
+ gradient->value_ = v;
}
}
- if (reset_initial_values) {
- initial_value_ = value_;
- }
return Status::OK();
}
kStack,
};
- XlaResource(Kind kind, int arg_num, string name, DataType initial_type,
- const xla::ComputationDataHandle& initial_value);
+ XlaResource(Kind kind, int arg_num, string name, DataType type,
+ TensorShape shape,
+ const xla::ComputationDataHandle& initial_value,
+ int64 tensor_array_size,
+ const std::set<string>& tensor_array_gradients);
XlaResource(const XlaResource&) = delete;
XlaResource(XlaResource&&) = delete;
// a resource is first initialized we do not yet know its type, so we keep
// track of its type dynamically.
DataType type() const { return type_; }
+
+ // Shape of the resource. For an uninitialized resource, this is ignored.
+ // For a Variable, this is the shape of the value. For a TensorArray or Stack
+ // this is the shape of each entry in the TensorArray/Stack.
+ const TensorShape& shape() const { return shape_; }
+
const xla::ComputationDataHandle& value() const { return value_; }
// Value of the resource at computation entry. Used to detect which
return initial_value_;
}
+ // A variable is initialized if it has a value.
bool initialized() const { return value_.handle() > 0; }
- // Sets the current type/value of the resource.
- Status SetValue(DataType type, const xla::ComputationDataHandle& value);
+ // Sets the type and shape of the resource. The type and shape of a resource
+ // must not change once the variable has been initialized.
+ Status SetTypeAndShape(DataType type, const TensorShape& shape);
- // Returns the shape of the resource as an xla::Shape.
- Status GetXlaShape(xla::ComputationBuilder* builder, xla::Shape* shape) const;
+ // Sets the current value of the resource. Returns an error if the type is not
+ // set to a valid value.
+ Status SetValue(const xla::ComputationDataHandle& value);
- // Returns the shape of the resource as an TensorShape. Fails if the shape is
- // not representable as a TensorShape.
- Status GetShape(xla::ComputationBuilder* builder, TensorShape* shape) const;
+ // Sets the current value of the resource to an all-zero value.
+ Status SetZeroValue(xla::ComputationBuilder* builder);
// Looks up the gradient for `source`, or creates it if it does not already
// exist. The call target must be an initialized TensorArray resource. A
Status Pack(xla::ComputationDataHandle* pack,
xla::ComputationBuilder* builder) const;
- // Returns the shape of the `pack` value computed by `Pack()`.
- Status PackedShape(xla::ComputationBuilder* builder,
- xla::Shape* packed_shape) const;
-
// Updates the resource with values from `pack`. If `gradient_sources` is
// non-empty, treats `pack` as a tuple that represents a TensorArray and
// its gradients, and unpacks and updates the gradient resources.
// Opposite of Pack().
Status SetFromPack(const std::set<string>& gradient_sources,
const xla::ComputationDataHandle& pack,
- bool reset_initial_values,
xla::ComputationBuilder* builder);
- // TensorArray-specific fields
+ // TensorArray and Stack specific fields
// 'tensor_array_size' stores the expected size of the TensorArray or Stack.
// We need to store this since sometimes TensorArrays must be initialized
// lazily since we do not know the element shape at construction time.
+ // Used by both TensorArrays and Stacks.
int64 tensor_array_size() const { return tensor_array_size_; }
void set_tensor_array_size(int64 size) { tensor_array_size_ = size; }
const string name_;
DataType type_;
+ TensorShape shape_;
xla::ComputationDataHandle value_;
xla::ComputationDataHandle initial_value_;