#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
// Tests a tuple-shaped constant.
XLA_TEST_F(TupleTest, TupleConstant) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
const float constant_scalar = 7.3f;
std::initializer_list<float> constant_vector = {1.1f, 2.0f, 3.3f};
Literal::CreateR1<float>(constant_vector).get(),
Literal::CreateR2<float>(constant_matrix).get()});
- auto result = builder.ConstantLiteral(*value);
+ builder.ConstantLiteral(*value);
ComputeAndCompareTuple(&builder, *value, {}, error_spec_);
}
// Tests a tuple made of scalar constants.
XLA_TEST_F(TupleTest, TupleScalarConstant) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
const float constant_scalar1 = 7.3f;
const float constant_scalar2 = 1.2f;
Literal::MakeTuple({Literal::CreateR0<float>(constant_scalar1).get(),
Literal::CreateR0<float>(constant_scalar2).get()});
- auto result = builder.ConstantLiteral(*value);
+ builder.ConstantLiteral(*value);
ComputeAndCompareTuple(&builder, *value, {}, error_spec_);
}
// Tests the creation of tuple data.
XLA_TEST_F(TupleTest, TupleCreate) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
const float constant_scalar = 7.3f;
std::initializer_list<float> constant_vector = {1.1f, 2.0f, 3.3f};
{1.1f, 2.2f, 3.5f}, // row 0
{4.8f, 5.0f, 6.7f}, // row 1
};
- auto result = builder.Tuple({builder.ConstantR0<float>(constant_scalar),
- builder.ConstantR1<float>(constant_vector),
- builder.ConstantR2<float>(constant_matrix)});
+ builder.Tuple({builder.ConstantR0<float>(constant_scalar),
+ builder.ConstantR1<float>(constant_vector),
+ builder.ConstantR2<float>(constant_matrix)});
auto expected =
Literal::MakeTuple({Literal::CreateR0<float>(constant_scalar).get(),
// Tests the creation of tuple data.
XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
- auto result = builder.Tuple(
+ builder.Tuple(
{builder.ConstantR0<float>(7.0), builder.ConstantR1<float>({})});
auto expected = Literal::MakeTuple({Literal::CreateR0<float>(7.0).get(),
// Tests the creation of an empty tuple.
XLA_TEST_F(TupleTest, EmptyTupleCreate) {
- ComputationBuilder builder(client_, TestName());
- auto result = builder.Tuple({});
+ XlaBuilder builder(TestName());
+ builder.Tuple({});
auto expected = Literal::MakeTuple({});
ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
}
// Trivial test for extracting a tuple element with GetTupleElement.
XLA_TEST_F(TupleTest, GetTupleElement) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f};
std::initializer_list<std::initializer_list<float>> constant_matrix = {
{1.f, 2.f, 3.f}, // row 0
};
auto tuple_data = builder.Tuple({builder.ConstantR1<float>(constant_vector),
builder.ConstantR2<float>(constant_matrix)});
- auto matrix_element = builder.GetTupleElement(tuple_data, 1);
+ builder.GetTupleElement(tuple_data, 1);
ComputeAndCompareR2<float>(&builder, Array2D<float>(constant_matrix), {},
error_spec_);
}
// Trivial test for extracting a tuple element with GetTupleElement.
XLA_TEST_F(TupleTest, GetTupleElementWithZeroElements) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto tuple_data = builder.Tuple(
{builder.ConstantR1<float>({}),
builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 101))});
- auto matrix_element = builder.GetTupleElement(tuple_data, 1);
+ builder.GetTupleElement(tuple_data, 1);
ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 101), {}, error_spec_);
}
XLA_TEST_F(TupleTest, GetTupleElementOfNonTupleFailsGracefully) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto value = builder.ConstantR1<float>({4.5f});
builder.GetTupleElement(value, 1);
auto result_status = builder.Build();
// Extracts both elements from a tuple with GetTupleElement and then adds them
// together.
XLA_TEST_F(TupleTest, AddTupleElements) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f};
std::initializer_list<std::initializer_list<float>> constant_matrix = {
{1.f, 2.f, 3.f}, // row 0
auto matrix_element = builder.GetTupleElement(tuple_data, 1);
auto vector_shape = builder.GetShape(vector_element).ConsumeValueOrDie();
auto matrix_shape = builder.GetShape(matrix_element).ConsumeValueOrDie();
- auto result = builder.Add(matrix_element, vector_element,
- /*broadcast_dimensions=*/{1});
+ builder.Add(matrix_element, vector_element,
+ /*broadcast_dimensions=*/{1});
Array2D<float> expected({
{2.f, 4.f, 6.f}, // row 0
{5.f, 7.f, 9.f}, // row 1
});
- ASSERT_TRUE(ShapeUtil::ShapeIs(*vector_shape, F32, {3}));
- ASSERT_TRUE(ShapeUtil::ShapeIs(*matrix_shape, F32, {/*y=*/2, /*x=*/3}));
+ ASSERT_TRUE(ShapeUtil::ShapeIs(vector_shape, F32, {3}));
+ ASSERT_TRUE(ShapeUtil::ShapeIs(matrix_shape, F32, {/*y=*/2, /*x=*/3}));
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
// Extracts both elements from a tuple and then puts them into a new tuple in
// the opposite order.
XLA_TEST_F(TupleTest, TupleGTEToTuple) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f};
std::initializer_list<std::initializer_list<float>> constant_matrix = {
{1.f, 2.f, 3.f}, // row 0
};
auto tuple_data = builder.Tuple({builder.ConstantR1<float>(constant_vector),
builder.ConstantR2<float>(constant_matrix)});
- auto new_tuple = builder.Tuple({builder.GetTupleElement(tuple_data, 1),
- builder.GetTupleElement(tuple_data, 0)});
+ builder.Tuple({builder.GetTupleElement(tuple_data, 1),
+ builder.GetTupleElement(tuple_data, 0)});
auto expected =
Literal::MakeTuple({Literal::CreateR2<float>(constant_matrix).get(),
Literal::CreateR1<float>(constant_vector).get()});
}
XLA_TEST_F(TupleTest, SelectBetweenPredTuples) {
- ComputationBuilder b(client_, TestName());
- ComputationDataHandle v1, v2;
+ XlaBuilder b(TestName());
+ XlaOp v1, v2;
for (bool direction : {false, true}) {
std::unique_ptr<GlobalData> v1_data =
auto v2_gt = b.Gt(v2, v1); // true
auto v1_v2 = b.Tuple({v1_gt, v2_gt}); // {false, true}
auto v2_v1 = b.Tuple({v2_gt, v1_gt}); // {true, false}
- auto select = b.Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1);
+ b.Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1);
auto expected =
Literal::MakeTuple({Literal::CreateR0<bool>(direction).get(),
Literal::CreateR0<bool>(!direction).get()});
// \ (tuple10)-- /
// \ / \ /
// -----(GTE 0)-- --(GTE 1)----------
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f};
std::initializer_list<std::initializer_list<float>> constant_matrix = {
{1.f, 2.f, 3.f}, // row 0
auto addvectors = builder.Add(vector_from_01, vector_from_10);
auto addmatrices = builder.Add(matrix_from_01, matrix_from_10);
- auto result = builder.Add(addmatrices, addvectors,
- /*broadcast_dimensions=*/{1});
+ builder.Add(addmatrices, addvectors,
+ /*broadcast_dimensions=*/{1});
Array2D<float> expected({
{4.f, 8.f, 12.f}, // row 0
XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnFalse)) {
// Tests a selection between tuples with "false" path taken.
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
auto tuple21 = builder.Tuple(
{builder.ConstantR1<float>(vec2), builder.ConstantR1<float>(vec1)});
- auto select =
- builder.Select(builder.ConstantR0<bool>(false), tuple12, tuple21);
+ builder.Select(builder.ConstantR0<bool>(false), tuple12, tuple21);
auto expected = Literal::MakeTuple({Literal::CreateR1<float>(vec2).get(),
Literal::CreateR1<float>(vec1).get()});
ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnTrue)) {
// Tests a selection between tuples with "true" path taken.
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
auto tuple21 = builder.Tuple(
{builder.ConstantR1<float>(vec2), builder.ConstantR1<float>(vec1)});
- auto select =
- builder.Select(builder.ConstantR0<bool>(true), tuple12, tuple21);
+ builder.Select(builder.ConstantR0<bool>(true), tuple12, tuple21);
auto expected = Literal::MakeTuple({Literal::CreateR1<float>(vec1).get(),
Literal::CreateR1<float>(vec2).get()});
ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
XLA_TEST_F(TupleTest, SelectBetweenTuplesElementResult) {
// Tests a selection between tuples but the final result is an element of the
// tuple, not the whole tuple.
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
auto select =
builder.Select(builder.ConstantR0<bool>(false), tuple12, tuple21);
- auto element = builder.GetTupleElement(select, 0);
+ builder.GetTupleElement(select, 0);
ComputeAndCompareR1<float>(&builder, vec2, {}, error_spec_);
}
// / --(GTE 1)--
// /
// (tuple 21)
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
builder.Select(builder.GetTupleElement(pred_tuple, 0), tuple12, tuple21);
auto select2 =
builder.Select(builder.GetTupleElement(pred_tuple, 1), tuple21, select1);
- auto result = builder.Add(builder.GetTupleElement(select2, 0),
- builder.GetTupleElement(select2, 1));
+ builder.Add(builder.GetTupleElement(select2, 0),
+ builder.GetTupleElement(select2, 1));
ComputeAndCompareR1<float>(&builder, {3.f, 6.f, 9.f}, {}, error_spec_);
}
DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesReuseConstants)) {
// Similar to SelectBetweenTuples, but the constants are shared between the
// input tuples.
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
auto tuple12 = builder.Tuple({c1, c2});
auto tuple21 = builder.Tuple({c2, c1});
- auto select =
- builder.Select(builder.ConstantR0<bool>(false), tuple12, tuple21);
+ builder.Select(builder.ConstantR0<bool>(false), tuple12, tuple21);
+
auto expected = Literal::MakeTuple({Literal::CreateR1<float>(vec2).get(),
Literal::CreateR1<float>(vec1).get()});
ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
}
XLA_TEST_F(TupleTest, NestedTuples) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto inner_tuple = builder.Tuple(
{builder.ConstantR1<float>({1.0, 2.0}), builder.ConstantR0<float>(42.0)});
- auto outer_tuple =
- builder.Tuple({inner_tuple, builder.ConstantR1<float>({22.0, 44.0})});
+ builder.Tuple({inner_tuple, builder.ConstantR1<float>({22.0, 44.0})});
auto expected_v1 = Literal::CreateR1<float>({1.0, 2.0});
auto expected_s = Literal::CreateR0<float>(42.0);
}
XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Shape data_shape = ShapeUtil::MakeShape(F32, {3});
Shape inner_tuple_shape = ShapeUtil::MakeTupleShape({data_shape, data_shape});
}
XLA_TEST_F(TupleTest, ComplexTuples) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
{
Shape c64r0 = ShapeUtil::MakeShape(C64, {});
Shape c64r1 = ShapeUtil::MakeShape(C64, {2});