}
XlaBuilder::XlaBuilder(const string& computation_name)
- : name_(computation_name) {}
+ : name_(computation_name), unique_id_(GetUniqueId()) {}
XlaBuilder::~XlaBuilder() {}
}
HloComputationProto entry;
- entry.set_name(name_);
{
int64 root_id;
entry.add_instructions()->Swap(&instruction);
}
- const int64 id = GetUniqueId();
- entry.set_id(id);
- XlaComputation computation(id);
+ entry.set_id(unique_id_);
+ entry.set_name(StrCat(name_, entry.id())); // Ensure that the name is unique.
+ XlaComputation computation(entry.id());
HloModuleProto* module = computation.mutable_proto();
module->set_name(entry.name());
module->set_id(entry.id());
ShapeInference::InferCallShape(operand_shape_ptrs,
/*to_apply=*/called_program_shape));
- // Add called computation.
- instr.add_called_computation_ids(
- computation.proto().entry_computation_id());
- for (const HloComputationProto& e : computation.proto().computations()) {
- embedded_.insert({e.id(), e});
- }
+ AddCalledComputation(computation, &instr);
return AddInstruction(std::move(instr), HloOpcode::kCall, operands);
});
tensorflow::gtl::ArraySlice<int64> start_indices,
tensorflow::gtl::ArraySlice<int64> limit_indices,
tensorflow::gtl::ArraySlice<int64> strides) {
- return UnimplementedOp();
+ return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+ TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
+ TF_ASSIGN_OR_RETURN(
+ *instr.mutable_shape(),
+ ShapeInference::InferSliceShape(operand_shape, start_indices,
+ limit_indices, strides));
+ for (int i = 0; i < start_indices.size(); i++) {
+ auto* slice_config = instr.add_slice_dimensions();
+ slice_config->set_start(start_indices[i]);
+ slice_config->set_limit(limit_indices[i]);
+ slice_config->set_stride(strides[i]);
+ }
+
+ return AddInstruction(std::move(instr), HloOpcode::kSlice, {operand});
+ });
}
XlaOp XlaBuilder::SliceInDim(const XlaOp& operand, int64 start_index,
XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
const XlaOp& start_indices) {
- return UnimplementedOp();
+ return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+
+ TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
+ TF_ASSIGN_OR_RETURN(const Shape& update_shape, GetShape(update));
+ TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape,
+ GetShape(start_indices));
+ TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
+ ShapeInference::InferDynamicUpdateSliceShape(
+ operand_shape, update_shape, start_indices_shape));
+
+ return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice,
+ {operand, update, start_indices});
+ });
}
XlaOp XlaBuilder::ConcatInDim(tensorflow::gtl::ArraySlice<XlaOp> operands,
}
XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs) {
- return UnimplementedOp();
+ return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
+
+ DotDimensionNumbers dimension_numbers;
+ dimension_numbers.add_lhs_contracting_dimensions(
+ lhs_shape.dimensions_size() == 1 ? 0 : 1);
+ dimension_numbers.add_rhs_contracting_dimensions(0);
+ return DotGeneral(lhs, rhs, dimension_numbers);
+ });
}
XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
const DotDimensionNumbers& dimension_numbers) {
- return UnimplementedOp();
+ return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+ TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
+ TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
+ TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
+ ShapeInference::InferDotOpShape(lhs_shape, rhs_shape,
+ dimension_numbers));
+ *instr.mutable_dot_dimension_numbers() = dimension_numbers;
+ return AddInstruction(std::move(instr), HloOpcode::kDot, {lhs, rhs});
+ });
}
XlaOp XlaBuilder::Conv(const XlaOp& lhs, const XlaOp& rhs,
XlaOp XlaBuilder::ConvertElementType(const XlaOp& operand,
PrimitiveType new_element_type) {
- return UnimplementedOp();
+ return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+ TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
+ TF_ASSIGN_OR_RETURN(
+ *instr.mutable_shape(),
+ ShapeInference::InferConvertShape(operand_shape, new_element_type));
+ return AddInstruction(std::move(instr), HloOpcode::kConvert, {operand});
+ });
}
XlaOp XlaBuilder::BitcastConvertType(const XlaOp& operand,
return UnimplementedOp();
}
+XlaOp XlaBuilder::RngOp(RandomDistribution distribution,
+ tensorflow::gtl::ArraySlice<XlaOp> parameters,
+ const Shape& shape) {
+ return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+
+ // Check the number of parameters per RNG distribution.
+ switch (distribution) {
+ case RandomDistribution::RNG_NORMAL:
+ case RandomDistribution::RNG_UNIFORM:
+ if (parameters.size() != 2) {
+ return InvalidArgument(
+ "RNG distribution (%s) expects 2 parameters, but got %ld",
+ RandomDistribution_Name(distribution).c_str(), parameters.size());
+ }
+ break;
+ default:
+ LOG(FATAL) << "unhandled distribution " << distribution;
+ }
+
+ TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
+ *instr.mutable_shape() = shape;
+
+ instr.set_distribution(distribution);
+
+ return AddInstruction(std::move(instr), HloOpcode::kRng, parameters);
+ });
+}
+
XlaOp XlaBuilder::RngNormal(const XlaOp& mu, const XlaOp& sigma,
const Shape& shape) {
- return UnimplementedOp();
+ return RngOp(RandomDistribution::RNG_NORMAL, {mu, sigma}, shape);
}
XlaOp XlaBuilder::RngUniform(const XlaOp& a, const XlaOp& b,
const Shape& shape) {
- return UnimplementedOp();
+ return RngOp(RandomDistribution::RNG_UNIFORM, {a, b}, shape);
}
XlaOp XlaBuilder::While(const XlaComputation& condition,
const XlaComputation& body, const XlaOp& init) {
- return UnimplementedOp();
+ return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+
+ // Infer shape.
+ TF_ASSIGN_OR_RETURN(const auto& body_program_shape, body.GetProgramShape());
+ TF_ASSIGN_OR_RETURN(const auto& condition_program_shape,
+ condition.GetProgramShape());
+ TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init));
+ TF_ASSIGN_OR_RETURN(
+ *instr.mutable_shape(),
+ ShapeInference::InferWhileShape(condition_program_shape,
+ body_program_shape, init_shape));
+ // Body comes before condition computation in the vector.
+ AddCalledComputation(body, &instr);
+ AddCalledComputation(condition, &instr);
+ return AddInstruction(std::move(instr), HloOpcode::kWhile, {init});
+ });
}
XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& gather_indices,
const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
- return UnimplementedOp();
+ return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+
+ TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
+ TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init_value));
+ TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
+ computation.GetProgramShape());
+ TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
+ ShapeInference::InferReduceShape(
+ operand_shape, init_shape, dimensions_to_reduce,
+ called_program_shape));
+
+ for (int64 dim : dimensions_to_reduce) {
+ instr.add_dimensions(dim);
+ }
+
+ AddCalledComputation(computation, &instr);
+
+ return AddInstruction(std::move(instr), HloOpcode::kReduce,
+ {operand, init_value});
+ });
}
XlaOp XlaBuilder::ReduceAll(const XlaOp& operand, const XlaOp& init_value,
instr.set_id(handle);
instr.set_opcode(HloOpcodeString(opcode));
if (instr.name().empty()) {
- instr.set_name(StrCat(instr.opcode(), ".", handle));
+ instr.set_name(StrCat(instr.opcode(), ".", unique_id_, ".", handle));
} else {
// Append the handle to make sure the name is unique.
- instr.set_name(StrCat(instr.name(), ".", handle));
+ instr.set_name(StrCat(instr.name(), ".", unique_id_, ".", handle));
}
for (const auto& operand : operands) {
if (operand.builder_ == nullptr) {
return op;
}
+void XlaBuilder::AddCalledComputation(const XlaComputation& computation,
+ HloInstructionProto* instr) {
+ instr->add_called_computation_ids(computation.proto().entry_computation_id());
+ for (const HloComputationProto& e : computation.proto().computations()) {
+ embedded_.insert({e.id(), e});
+ }
+}
+
StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction(
const XlaOp& op) const {
TF_RETURN_IF_ERROR(first_error_);
#include <vector>
#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
auto result_shape = ShapeUtil::MakeShape(S32, {});
// Create a computation for the condition: repeat for 5 iterations.
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
builder.Gt(builder.ConstantR0<int32>(5), prev);
condition = builder.Build().ConsumeValueOrDie();
}
// Create a computation for the body: add 1 to the result variable.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto input = builder.ConstantR0<int32>(1);
- auto result = builder.Add(input, prev);
+ builder.Add(input, prev);
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto init = builder.ConstantR0<int32>(0);
- auto result = builder.While(condition, body, init);
- auto shape = builder.GetShape(result).ConsumeValueOrDie();
+ builder.While(condition, body, init);
ComputeAndCompareR0<int32>(&builder, 5, {});
}
auto result_shape = ShapeUtil::MakeShape(S64, {});
// Create a computation for the condition: repeat for 5 iterations.
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
builder.Gt(builder.ConstantR0<int64>(5), prev);
condition = builder.Build().ConsumeValueOrDie();
}
// Create a computation for the body: add 1 to the result variable.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto input = builder.ConstantR0<int64>(1);
- auto result = builder.Add(input, prev);
+ builder.Add(input, prev);
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto init = builder.ConstantR0<int64>(0);
- auto result = builder.While(condition, body, init);
- auto shape = builder.GetShape(result).ConsumeValueOrDie();
+ builder.While(condition, body, init);
ComputeAndCompareR0<int64>(&builder, 5, {});
}
auto orig_shape = ShapeUtil::MakeShape(S32, {2});
// Create a computation for the condition: repeat for 5 iterations.
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
builder.Gt(builder.ConstantR0<int32>(5), prev);
condition = builder.Build().ConsumeValueOrDie();
}
// Create a computation for the body: add 1 to the result variable.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto input = builder.ConstantR0<int32>(1);
- auto result = builder.Add(input, prev);
+ builder.Add(input, prev);
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto init = builder.Reduce(builder.ConstantR1<int32>(2, 1),
builder.ConstantR0<int32>(0),
CreateScalarAddComputation(S32, &builder), {0});
- auto result = builder.While(condition, body, init);
- auto shape = builder.GetShape(result).ConsumeValueOrDie();
+ builder.While(condition, body, init);
ComputeAndCompareR0<int32>(&builder, 5, {});
}
auto result_shape = ShapeUtil::MakeShape(PRED, {});
// Create a computation for the condition: run until condition is true.
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
builder.Ne(builder.ConstantR0<bool>(true), prev);
condition = builder.Build().ConsumeValueOrDie();
}
// Create a computation for the body: or condition with true.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
- auto result = builder.Or(prev, builder.ConstantR0<bool>(true));
+ builder.Or(prev, builder.ConstantR0<bool>(true));
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto init = builder.Ne(builder.ConstantR0<bool>(false),
builder.ConstantR0<bool>(true));
- auto result = builder.While(condition, body, init);
+ builder.While(condition, body, init);
ComputeAndCompareR0<bool>(&builder, true, {});
}
Shape result_shape = ShapeUtil::MakeShape(F32, {0});
// Create a computation for the reduction.
- Computation add;
+ XlaComputation add;
{
- ComputationBuilder builder(client_, "add");
+ XlaBuilder builder("add");
auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
builder.Add(x, y);
// Create a computation for the condition.
// Repeat until the sum of the result vector is less than 15.5f.
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
auto sum = builder.Reduce(prev, builder.ConstantR0<float>(0.0f), add,
/*dimensions_to_reduce=*/{0});
- auto test = builder.Gt(builder.ConstantR0<float>(15.5f), sum);
+ builder.Gt(builder.ConstantR0<float>(15.5f), sum);
condition = builder.Build().ConsumeValueOrDie();
}
// Create a computation for the body.
// Add a constant vector of 1.f to the result vector.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto input = builder.ConstantR1<float>({});
- auto result = builder.Add(input, prev);
+ builder.Add(input, prev);
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, "while");
+ XlaBuilder builder("while");
auto init = builder.ConstantR1<float>({});
auto result = builder.While(condition, body, init);
- VLOG(2) << "while = " << ShapeUtil::HumanString(
- *builder.GetShape(result).ConsumeValueOrDie());
+ VLOG(2) << "while = "
+ << ShapeUtil::HumanString(
+ builder.GetShape(result).ConsumeValueOrDie());
ComputeAndCompareR1<float>(&builder, {}, {}, ErrorSpec(0.0001));
}
Shape result_shape = ShapeUtil::MakeShape(F32, {8});
// Create a computation for the reduction.
- Computation add;
+ XlaComputation add;
{
- ComputationBuilder builder(client_, "add");
+ XlaBuilder builder("add");
auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
builder.Add(x, y);
// Create a computation for the condition.
// Repeat until the sum of the result vector is less than 5.5f.
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
auto sum = builder.Reduce(prev, builder.ConstantR0<float>(0.0f), add,
/*dimensions_to_reduce=*/{0});
- auto test = builder.Gt(builder.ConstantR0<float>(15.5f), sum);
+ builder.Gt(builder.ConstantR0<float>(15.5f), sum);
condition = builder.Build().ConsumeValueOrDie();
}
// Create a computation for the body.
// Add a constant vector of 1.f to the result vector.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto input = builder.ConstantR1<float>(8, 0.125f);
- auto result = builder.Add(input, prev);
+ builder.Add(input, prev);
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, "while");
+ XlaBuilder builder("while");
auto init = builder.ConstantR1<float>(8, 0.f);
auto result = builder.While(condition, body, init);
- VLOG(2) << "while = " << ShapeUtil::HumanString(
- *builder.GetShape(result).ConsumeValueOrDie());
+ VLOG(2) << "while = "
+ << ShapeUtil::HumanString(
+ builder.GetShape(result).ConsumeValueOrDie());
// Individual elements with increase by 1/8 each time through the loop, so
// the sum will increase by 1.0. It will first be >15.5 when the elements
Shape result_shape = ShapeUtil::MakeShape(F32, {8});
// Create a computation for the reduction.
- Computation add;
+ XlaComputation add;
{
- ComputationBuilder builder(client_, "add");
+ XlaBuilder builder("add");
auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
builder.Add(x, y);
// Create a computation for the condition.
// Repeat until the sum of the result vector is less than 5.5f.
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
auto sum = builder.Reduce(prev, builder.ConstantR0<float>(0.0f), add,
/*dimensions_to_reduce=*/{0});
- auto test = builder.Gt(builder.ConstantR0<float>(15.5f), sum);
+ builder.Gt(builder.ConstantR0<float>(15.5f), sum);
condition = builder.Build().ConsumeValueOrDie();
}
// Create a computation for the body.
// Add a constant vector of 1.f to the result vector.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto input = builder.ConstantR1<float>(8, 0.125f);
- auto result = builder.Add(input, prev);
+ builder.Add(input, prev);
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, "while");
+ XlaBuilder builder("while");
auto init = builder.ConstantR1<float>(8, 0.f);
auto result = builder.While(condition, body, init);
VLOG(2) << "while = "
<< ShapeUtil::HumanString(
- *builder.GetShape(result).ConsumeValueOrDie());
+ builder.GetShape(result).ConsumeValueOrDie());
builder.Tuple({result});
// Individual elements with increase by 1/8 each time through the loop, so
// Create a computation for the condition.
// Repeat for N iterations.
const int N = 2;
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
builder.Gt(builder.ConstantR0<int32>(N), iteration);
// Create a computation for the body.
// Add 1 to the iteration variable and permute the weights.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
auto w1 = builder.GetTupleElement(prev, 1);
auto w2 = builder.GetTupleElement(prev, 2);
auto w3 = builder.GetTupleElement(prev, 3);
- auto result = builder.Tuple(
+ builder.Tuple(
{builder.Add(iteration, builder.ConstantR0<int32>(1)), w3, w1, w2});
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, "while");
+ XlaBuilder builder("while");
auto init = builder.Tuple(
{builder.ConstantR0<int32>(0), builder.ConstantR1<float>(3, 1.f),
builder.ConstantR1<float>(3, 2.f), builder.ConstantR1<float>(3, 3.f)});
auto result = builder.While(condition, body, init);
VLOG(2) << "result = "
<< ShapeUtil::HumanString(
- *builder.GetShape(result).ConsumeValueOrDie());
+ builder.GetShape(result).ConsumeValueOrDie());
auto expected_counter = Literal::CreateR0<int32>(N);
auto expected_w1 = Literal::CreateR1<float>({1.0f, 1.0f, 1.0f});
// Create a computation for the condition.
// Repeat for N iterations.
const int N = 2;
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
builder.Gt(builder.ConstantR0<int32>(N), iteration);
// Create a computation for the body.
// Add 1 to the iteration variable permute the weights.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
auto w1 = builder.GetTupleElement(prev, 1);
auto w2 = builder.GetTupleElement(prev, 2);
auto w3 = builder.GetTupleElement(prev, 3);
- auto result = builder.Tuple(
+ builder.Tuple(
{builder.Add(iteration, builder.ConstantR0<int32>(1)), w3, w1, w2});
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, "while");
+ XlaBuilder builder("while");
auto init = builder.Tuple(
{builder.ConstantR0<int32>(0), builder.ConstantR1<float>(3, 1.f),
builder.ConstantR1<float>(3, 2.f), builder.ConstantR1<float>(3, 3.f)});
auto result = builder.Add(add12, builder.GetTupleElement(xla_while, 3));
VLOG(2) << "result = "
<< ShapeUtil::HumanString(
- *builder.GetShape(result).ConsumeValueOrDie());
+ builder.GetShape(result).ConsumeValueOrDie());
std::vector<float> expected = {6.f, 6.f, 6.f};
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
}
// Create a computation for the condition.
// Repeat for 5 iterations.
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
builder.Gt(builder.ConstantR0<int32>(5), iteration);
// Create a computation for the body.
// Add 1 to the iteration variable and add a constant vector of 1.0f to
// the weight variable, both of which are tuple elements.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
auto weights = builder.GetTupleElement(prev, 1);
auto input = builder.ConstantR1<float>(10, 1.f);
auto new_weights = builder.Add(weights, input);
- auto result = builder.Tuple(
+ builder.Tuple(
{builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights});
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, "while");
+ XlaBuilder builder("while");
auto init = builder.Tuple(
{builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)});
auto result = builder.While(condition, body, init);
- VLOG(2) << "while = " << ShapeUtil::HumanString(
- *builder.GetShape(result).ConsumeValueOrDie());
+ VLOG(2) << "while = "
+ << ShapeUtil::HumanString(
+ builder.GetShape(result).ConsumeValueOrDie());
auto expected_counter = Literal::CreateR0<int32>(5);
auto expected_data = Literal::CreateR1<float>(
// Create a computation for the condition.
// Repeat for 5 iterations.
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
builder.Gt(builder.ConstantR0<int32>(5), iteration);
// Create a computation for the body.
// Add 1 to the iteration variable and or the predicate with true
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
auto pred = builder.GetTupleElement(prev, 1);
auto new_pred = builder.Or(pred, builder.ConstantR0<bool>(true));
- auto result = builder.Tuple(
+ builder.Tuple(
{builder.Add(iteration, builder.ConstantR0<int32>(1)), new_pred});
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, "while");
+ XlaBuilder builder("while");
auto init = builder.Tuple({builder.ConstantR0<int32>(0),
builder.Ne(builder.ConstantR0<bool>(false),
builder.ConstantR0<bool>(true))});
auto result = builder.While(condition, body, init);
VLOG(2) << "while = "
<< ShapeUtil::HumanString(
- *builder.GetShape(result).ConsumeValueOrDie());
+ builder.GetShape(result).ConsumeValueOrDie());
auto expected_counter = Literal::CreateR0<int32>(5);
auto expected_predicate = Literal::CreateR0<bool>(true);
// Create a computation for the condition.
// Repeat for 5 iterations.
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
builder.Gt(builder.ConstantR0<int32>(5), iteration);
// Create a computation for the body.
// Add 1 to the iteration variable and set the other tuple element to a
// constant.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
- auto result =
- builder.Tuple({builder.Add(iteration, builder.ConstantR0<int32>(1)),
- builder.ConstantR0<int32>(7)});
+ builder.Tuple({builder.Add(iteration, builder.ConstantR0<int32>(1)),
+ builder.ConstantR0<int32>(7)});
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, "while");
+ XlaBuilder builder("while");
auto init = builder.Tuple(
{builder.ConstantR0<int32>(0), builder.ConstantR0<int32>(7)});
auto result = builder.While(condition, body, init);
VLOG(2) << "while = "
<< ShapeUtil::HumanString(
- *builder.GetShape(result).ConsumeValueOrDie());
+ builder.GetShape(result).ConsumeValueOrDie());
auto expected_counter = Literal::CreateR0<int32>(5);
auto expected_data = Literal::CreateR0<int32>(7);
// Create a computation for the condition.
// Repeat for 5 iterations.
- Computation condition;
+ XlaComputation condition;
const int c1 = 5;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
builder.Lt(iteration, builder.ConstantR0<int32>(c1));
TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
}
- Computation condition2;
+ XlaComputation condition2;
const int c2 = 7;
{
- ComputationBuilder builder(client_, "condition2");
+ XlaBuilder builder("condition2");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
builder.Lt(iteration, builder.ConstantR0<int32>(c2));
// Create a computation for the body.
// Add 1 to the iteration variable and add a constant vector of 1.0f to
// the weight variable, both of which are tuple elements.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
auto weights = builder.GetTupleElement(prev, 1);
auto input = builder.ConstantR1<float>(10, 1.f);
auto new_weights = builder.Add(weights, input);
- auto result = builder.Tuple(
+ builder.Tuple(
{builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights});
TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
}
- Computation body2;
+ XlaComputation body2;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
auto weights = builder.GetTupleElement(prev, 1);
auto input = builder.ConstantR1<float>(10, 1.f);
auto new_weights = builder.Add(weights, input);
- auto result = builder.Tuple(
+ builder.Tuple(
{builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights});
TF_ASSERT_OK_AND_ASSIGN(body2, builder.Build());
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, "while");
+ XlaBuilder builder("while");
auto init = builder.Tuple(
{builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)});
auto while1 = builder.While(condition, body, init);
auto while_result2 = builder.GetTupleElement(while2, 1);
VLOG(2) << "while_result2 = "
<< ShapeUtil::HumanString(
- *builder.GetShape(while_result2).ConsumeValueOrDie());
+ builder.GetShape(while_result2).ConsumeValueOrDie());
auto result = builder.Add(while_result1, while_result2);
VLOG(2) << "result = "
<< ShapeUtil::HumanString(
- *builder.GetShape(result).ConsumeValueOrDie());
+ builder.GetShape(result).ConsumeValueOrDie());
const float sum = c1 + c2;
std::vector<float> expected(10, sum);
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
// Create a computation for the condition.
// Repeat for 5 iterations.
- Computation condition;
+ XlaComputation condition;
const int c1 = 5;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
builder.Lt(iteration, builder.ConstantR0<int32>(c1));
TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
}
- Computation condition2;
+ XlaComputation condition2;
const int c2 = 7;
{
- ComputationBuilder builder(client_, "condition2");
+ XlaBuilder builder("condition2");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
builder.Lt(iteration, builder.ConstantR0<int32>(c2));
// Create a computation for the body.
// Add 1 to the iteration variable and add a constant vector of 1.0f to
// the weight variable, both of which are tuple elements.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
auto weights = builder.GetTupleElement(prev, 1);
auto input = builder.ConstantR1<float>(10, 1.f);
auto new_weights = builder.Add(weights, input);
- auto result = builder.Tuple(
+ builder.Tuple(
{builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights});
TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, "while");
+ XlaBuilder builder("while");
auto init = builder.Tuple(
{builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)});
auto while1 = builder.While(condition, body, init);
auto while_result2 = builder.GetTupleElement(while2, 1);
VLOG(2) << "while_result2 = "
<< ShapeUtil::HumanString(
- *builder.GetShape(while_result2).ConsumeValueOrDie());
+ builder.GetShape(while_result2).ConsumeValueOrDie());
auto result = builder.Add(while_result1, while_result2);
VLOG(2) << "result = "
<< ShapeUtil::HumanString(
- *builder.GetShape(result).ConsumeValueOrDie());
+ builder.GetShape(result).ConsumeValueOrDie());
const float sum = c1 + c2;
std::vector<float> expected(10, sum);
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
// Create a computation for the condition.
// Repeat for 5 iterations.
- Computation condition;
+ XlaComputation condition;
const int c1 = 5;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
builder.Lt(iteration, builder.ConstantR0<int32>(c1));
TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
}
- Computation condition2;
+ XlaComputation condition2;
const int c2 = 7;
{
- ComputationBuilder builder(client_, "condition2");
+ XlaBuilder builder("condition2");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
builder.Lt(iteration, builder.ConstantR0<int32>(c2));
// Create a computation for the body.
// Add 1 to the iteration variable and add a constant vector of 1.0f to
// the weight variable, both of which are tuple elements.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
auto weights = builder.GetTupleElement(prev, 1);
auto input = builder.ConstantR1<float>(10, 1.f);
auto new_weights = builder.Add(weights, input);
- auto result = builder.Tuple(
+ builder.Tuple(
{builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights});
TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, "while");
+ XlaBuilder builder("while");
auto init = builder.Tuple(
{builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)});
auto while1 = builder.While(condition, body, init);
auto while_result2 = builder.GetTupleElement(while2, 1);
VLOG(2) << "while_result2 = "
<< ShapeUtil::HumanString(
- *builder.GetShape(while_result2).ConsumeValueOrDie());
+ builder.GetShape(while_result2).ConsumeValueOrDie());
auto result = builder.Add(while_result1, while_result2);
VLOG(2) << "result = "
<< ShapeUtil::HumanString(
- *builder.GetShape(result).ConsumeValueOrDie());
+ builder.GetShape(result).ConsumeValueOrDie());
const float sum = c1 + c2;
std::vector<float> expected(10, sum);
ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
// Create a computation for the condition.
// Repeat for 5 iterations.
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
builder.Gt(builder.ConstantR0<int32>(5), iteration);
// Create a computation for the body.
// Add 1 to the iteration variable and add a constant vector of 1.0f to
// the weight variable, both of which are tuple elements.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
// TupleElement 0
auto iteration = builder.GetTupleElement(prev, 0);
// UpdateSlice.
auto out1 = builder.DynamicUpdateSlice(input, update, starts);
- auto result = builder.Tuple({out0, out1});
+ builder.Tuple({out0, out1});
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, "while");
+ XlaBuilder builder("while");
auto init = builder.Tuple(
{builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)});
auto result = builder.While(condition, body, init);
VLOG(2) << "while = "
<< ShapeUtil::HumanString(
- *builder.GetShape(result).ConsumeValueOrDie());
+ builder.GetShape(result).ConsumeValueOrDie());
auto expected_counter = Literal::CreateR0<int32>(5);
auto expected_data = Literal::CreateR1<float>(
// Create a computation for the condition: repeat for count iterations.
auto build_condition = [this, v6s32](int count) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto prev = builder.Reshape(
builder.Slice(builder.Parameter(0, v6s32, "prev"), {0}, {1}, {1}), {0},
- {});
+ {});
builder.Gt(builder.ConstantR0<int32>(count), prev);
return builder.Build().ConsumeValueOrDie();
};
// Create a computation for the body: add 1 to the result variable.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, v6s32, "prev");
auto inc = builder.ConcatInDim(
{builder.ConstantR1<int32>({1}),
builder.ConstantR0<int32>(100),
ShapeUtil::MakeShape(S32, {5}))},
0);
- auto result = builder.Add(inc, prev);
+ builder.Add(inc, prev);
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
auto while_loop = [this, &body, build_condition](int count) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto init = builder.ConstantR1<int32>({0, 0, 0, 0, 0, 0});
- auto result = builder.While(build_condition(count), body, init);
- auto shape = builder.GetShape(result).ConsumeValueOrDie();
+ builder.While(build_condition(count), body, init);
return builder.Build();
};
auto inner_result_shape = ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})});
- Computation inner_condition;
+ XlaComputation inner_condition;
{
- ComputationBuilder builder(client_, "inner_condition");
+ XlaBuilder builder("inner_condition");
auto params = builder.Parameter(0, inner_result_shape, "prev");
auto i = builder.GetTupleElement(params, 0);
builder.Lt(i, builder.ConstantR0<int32>(7));
// Creates a computation for the outer loop condition:
// repeat while result < 30.
- Computation outer_condition;
+ XlaComputation outer_condition;
{
- ComputationBuilder builder(client_, "outer_condition");
+ XlaBuilder builder("outer_condition");
auto prev = builder.Parameter(0, outer_result_shape, "prev");
builder.Lt(prev, builder.ConstantR0<int32>(30));
outer_condition = builder.Build().ConsumeValueOrDie();
// Creates a computation for the inner loop body: add 1 to `i`, and add 2 to
// `result`.
- Computation inner_body;
+ XlaComputation inner_body;
{
- ComputationBuilder builder(client_, "inner_body");
+ XlaBuilder builder("inner_body");
auto params = builder.Parameter(0, inner_result_shape, "prev");
auto i = builder.GetTupleElement(params, 0);
auto result = builder.GetTupleElement(params, 1);
i = builder.Add(builder.ConstantR0<int32>(1), i);
result = builder.Add(builder.ConstantR0<int32>(2), result);
- auto output = builder.Tuple({i, result});
+ builder.Tuple({i, result});
inner_body = builder.Build().ConsumeValueOrDie();
}
// Creates a computation for the outer loop: run the inner loop with i = 0.
- Computation outer_body;
+ XlaComputation outer_body;
{
- ComputationBuilder builder(client_, "outer_body");
+ XlaBuilder builder("outer_body");
auto prev = builder.Parameter(0, outer_result_shape, "prev");
auto init = builder.Tuple({builder.ConstantR0<int32>(0), prev});
auto result = builder.While(inner_condition, inner_body, init);
- auto output = builder.GetTupleElement(result, 1);
+ builder.GetTupleElement(result, 1);
outer_body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto init = builder.ConstantR0<int32>(0);
- auto result = builder.While(outer_condition, outer_body, init);
- auto shape = builder.GetShape(result).ConsumeValueOrDie();
+ builder.While(outer_condition, outer_body, init);
ComputeAndCompareR0<int32>(&builder, 42, {});
}
auto result_shape = ShapeUtil::MakeShape(S32, {});
// Create a computation for the condition: repeat for 5 iterations.
- Computation condition_callee;
+ XlaComputation condition_callee;
{
- ComputationBuilder builder(client_, "condition_callee");
+ XlaBuilder builder("condition_callee");
auto prev = builder.Parameter(0, result_shape, "prev");
builder.Tuple({builder.Gt(builder.ConstantR0<int32>(5), prev)});
condition_callee = builder.Build().ConsumeValueOrDie();
}
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, result_shape, "prev");
auto result = builder.Call(condition_callee, {prev});
builder.GetTupleElement(result, 0);
}
// Create a computation for the body: add 1 to the result variable.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, result_shape, "prev");
auto input = builder.ConstantR0<int32>(1);
- auto result = builder.Add(input, prev);
+ builder.Add(input, prev);
body = builder.Build().ConsumeValueOrDie();
}
// Create a While node with computations for the condition and the body.
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto init = builder.ConstantR0<int32>(0);
- auto result = builder.While(condition, body, init);
- auto shape = builder.GetShape(result).ConsumeValueOrDie();
+ builder.While(condition, body, init);
ComputeAndCompareR0<int32>(&builder, 5, {});
}
{scalar_s32, matrix_shape, matrix_shape, matrix_shape});
// Create a computation for the condition: repeat for 5 iterations.
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto state = builder.Parameter(0, while_shape, "state");
builder.Gt(builder.ConstantR0<int32>(5), builder.GetTupleElement(state, 0));
TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build());
}
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto state = builder.Parameter(0, while_shape, "state");
auto indvar = builder.GetTupleElement(state, 0);
auto input_0 = builder.GetTupleElement(state, 1);
auto input_1 = builder.GetTupleElement(state, 2);
auto output = builder.Tanh(builder.Dot(input_0, input_1));
auto indvar_next = builder.Add(indvar, builder.ConstantR0<int32>(1));
- auto tuple_result = builder.Tuple({indvar_next, input_0, input_1, output});
+ builder.Tuple({indvar_next, input_0, input_1, output});
TF_ASSERT_OK_AND_ASSIGN(body, builder.Build());
}
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto matrix_input = builder.Parameter(0, matrix_shape, "matrix");
auto init = builder.Tuple(
{builder.ConstantR0<int32>(0), matrix_input, matrix_input, matrix_input});
// Create while condition computation with 'loop_limit'.
const int32 loop_limit = 100;
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client, "condition");
+ XlaBuilder builder("condition");
auto prev = builder.Parameter(0, loop_state_shape, "prev");
auto iteration = builder.GetTupleElement(prev, 0);
builder.Lt(iteration, builder.ConstantR0<int32>(loop_limit));
}
// Create while body computation with unit loop increment.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client, "body");
+ XlaBuilder builder("body");
auto prev = builder.Parameter(0, loop_state_shape, "prev");
// TupleElement 0
auto iteration = builder.GetTupleElement(prev, 0);
auto starts = builder.ConstantR1<int32>({0, 0, 0});
// UpdateSlice.
auto out1 = builder.DynamicUpdateSlice(input, update, starts);
- auto result = builder.Tuple({out0, out1});
+ builder.Tuple({out0, out1});
body = builder.Build().ConsumeValueOrDie();
}
// Create a While instruction.
- ComputationBuilder builder(client, "while");
+ XlaBuilder builder("while");
auto zero = builder.ConstantR0<float>(0.0);
auto input = builder.Broadcast(zero, {seq_len, 1024, 1024});
auto init = builder.Tuple({builder.ConstantR0<int32>(0), input});