return ::testing::UnitTest::GetInstance()->current_test_info()->name();
}
-template <typename BuilderT>
StatusOr<std::unique_ptr<GlobalData>> ClientLibraryTestBase::Execute(
- BuilderT* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ XlaBuilder* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
// Build the computation, as a convenience.
TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
return client_->Execute(computation, arguments, &execution_options_);
}
StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
- const Computation& computation,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
- const Shape* shape_with_output_layout) {
- ExecutionOptions execution_options = execution_options_;
- if (shape_with_output_layout != nullptr) {
- *execution_options.mutable_shape_with_output_layout() =
- *shape_with_output_layout;
- }
- return client_->ExecuteAndTransfer(computation, arguments,
- &execution_options);
-}
-
-StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
const XlaComputation& computation,
tensorflow::gtl::ArraySlice<GlobalData*> arguments,
const Shape* shape_with_output_layout) {
&execution_options);
}
-template <>
-StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
- ComputationBuilder* builder,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
- const Shape* shape_with_output_layout) {
- // Build the computation, as a convenience.
- TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
- return ExecuteAndTransfer(computation, arguments, shape_with_output_layout);
-}
-
-template <>
StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
XlaBuilder* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments,
const Shape* shape_with_output_layout) {
&execution_options);
}
-std::unique_ptr<GlobalData> ClientLibraryTestBase::ExecuteOrDie(
- ComputationBuilder* builder,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
- return Execute(builder, arguments).ConsumeValueOrDie();
-}
-
-std::unique_ptr<Literal> ClientLibraryTestBase::ExecuteAndTransferOrDie(
- ComputationBuilder* builder,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
- return ExecuteAndTransfer(builder, arguments).ConsumeValueOrDie();
-}
-
string ClientLibraryTestBase::ExecuteToString(
XlaBuilder* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
auto computation_status = builder->Build();
}
}
-string ClientLibraryTestBase::ExecuteToString(
- ComputationBuilder* builder,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
- auto computation_status = builder->Build();
- if (!computation_status.ok()) {
- return computation_status.status().ToString();
- }
- auto computation = computation_status.ConsumeValueOrDie();
-
- auto result =
- client_->ExecuteAndTransfer(computation, arguments, &execution_options_);
- if (!result.ok()) {
- return result.status().ToString();
- } else {
- return result.ValueOrDie()->ToString();
- }
-}
-
-void ClientLibraryTestBase::ComputeAndCompareR1(
- ComputationBuilder* builder, const tensorflow::core::Bitmap& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
- std::unique_ptr<Literal> expected_literal = Literal::CreateR1(expected);
- ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
- arguments);
-}
-
void ClientLibraryTestBase::ComputeAndCompareR1(
XlaBuilder* builder, const tensorflow::core::Bitmap& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
arguments);
}
-template <typename BuilderT>
void ClientLibraryTestBase::ComputeAndCompareLiteral(
- BuilderT* builder, const Literal& expected,
+ XlaBuilder* builder, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments,
const Shape* shape_with_layout) {
EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments,
shape_with_layout));
}
-template <typename BuilderT>
void ClientLibraryTestBase::ComputeAndCompareLiteral(
- BuilderT* builder, const Literal& expected,
+ XlaBuilder* builder, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error,
const Shape* shape_with_layout) {
EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments,
tensorflow::Status
ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts(
- const xla::Computation& computation, const Literal& expected,
+ const xla::XlaComputation& computation, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments,
const std::function<void(const Literal& actual,
const string& error_message)>& verify_output) {
tensorflow::Status
ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
- const xla::Computation& computation, const Literal& expected,
+ const xla::XlaComputation& computation, const Literal& /*expected*/,
tensorflow::gtl::ArraySlice<GlobalData*> arguments,
const std::function<void(const Literal& actual,
const string& error_message)>& verify_output,
return choose(0);
}
-tensorflow::Status
-ClientLibraryTestBase::ComputeAndCompareLiteralWithAllOutputLayouts(
- const xla::XlaComputation& /*computation*/, const Literal& /*expected*/,
- tensorflow::gtl::ArraySlice<GlobalData*> /*arguments*/,
- const std::function<void(const Literal& actual,
- const string& error_message)>& /*verify_output*/) {
- return Unimplemented("not yet implemented for XlaComputation");
-}
-
-tensorflow::Status
-ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts(
- const xla::XlaComputation& /*computation*/, const Literal& /*expected*/,
- tensorflow::gtl::ArraySlice<GlobalData*> /*arguments*/,
- const std::function<void(const Literal& actual,
- const string& error_message)>& /*verify_output*/,
- const Shape* /*output_with_layout*/) {
- return Unimplemented("not yet implemented for XlaComputation");
-}
-
-template <typename BuilderT>
tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
- BuilderT* builder, const Literal& expected,
+ XlaBuilder* builder, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments_passed_in,
const Shape* shape_with_layout) {
std::vector<GlobalData*> arguments(arguments_passed_in.begin(),
return tensorflow::Status::OK();
}
-template <typename BuilderT>
tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
- BuilderT* builder, const Literal& expected,
+ XlaBuilder* builder, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments_passed_in,
ErrorSpec error, const Shape* shape_with_layout) {
std::vector<GlobalData*> arguments(arguments_passed_in.begin(),
EXPECT_EQ(expected, actual->GetR1U8AsString());
}
-template <typename BuilderT>
void ClientLibraryTestBase::ComputeAndCompareTuple(
- BuilderT* builder, const Literal& expected,
+ XlaBuilder* builder, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
auto actual_status = ExecuteAndTransfer(builder, arguments);
EXPECT_IS_OK(actual_status.status());
LiteralTestUtil::ExpectEqual(expected, *actual);
}
-template <typename BuilderT>
void ClientLibraryTestBase::ComputeAndCompareTuple(
- BuilderT* builder, const Literal& expected,
+ XlaBuilder* builder, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
auto actual_status = ExecuteAndTransfer(builder, arguments);
EXPECT_IS_OK(actual_status.status());
}
void ClientLibraryTestBase::ComputeAndCompare(
- ComputationBuilder* builder, const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<Literal> arguments) {
- auto status_or_data = ComputeValueAndReference(builder, operand, arguments);
- EXPECT_IS_OK(status_or_data);
- if (!status_or_data.ok()) {
- return;
- }
- std::unique_ptr<Literal> reference, result;
- std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
- LiteralTestUtil::ExpectEqual(*reference, *result);
-}
-
-void ClientLibraryTestBase::ComputeAndCompare(
- ComputationBuilder* builder, const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<Literal> arguments, ErrorSpec error) {
- auto status_or_data = ComputeValueAndReference(builder, operand, arguments);
- EXPECT_IS_OK(status_or_data);
- if (!status_or_data.ok()) {
- return;
- }
- std::unique_ptr<Literal> reference, result;
- std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
- LiteralTestUtil::ExpectNear(*reference, *result, error);
-}
-
-StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
-ClientLibraryTestBase::ComputeValueAndReference(
- ComputationBuilder* builder, const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<Literal> arguments) {
- // Transfer the arguments to the executor service. We put the unique_ptr's
- // into a vector to keep the data alive on the service until the end of this
- // function.
- std::vector<std::unique_ptr<GlobalData>> argument_data;
- for (const auto& arg : arguments) {
- TF_ASSIGN_OR_RETURN(auto data, client_->TransferToServer(arg));
- argument_data.push_back(std::move(data));
- }
-
- // Create raw pointers to the GlobalData for the rest of the call stack.
- std::vector<GlobalData*> argument_data_ptr;
- std::transform(
- argument_data.begin(), argument_data.end(),
- std::back_inserter(argument_data_ptr),
- [](const std::unique_ptr<GlobalData>& data) { return data.get(); });
-
- TF_ASSIGN_OR_RETURN(
- auto reference,
- builder->ComputeConstant(operand, /*output_layout=*/nullptr, arguments));
- TF_ASSIGN_OR_RETURN(auto result,
- ExecuteAndTransfer(builder, argument_data_ptr));
- return std::make_pair(std::move(reference), std::move(result));
-}
-
-void ClientLibraryTestBase::ComputeAndCompare(
XlaBuilder* builder, tensorflow::gtl::ArraySlice<Literal> arguments) {
auto status_or_data = ComputeValueAndReference(builder, arguments);
EXPECT_IS_OK(status_or_data);
return computation_status.ConsumeValueOrDie();
}
-Computation ClientLibraryTestBase::CreateScalarReluSensitivity() {
- ComputationBuilder builder(client_, "relu_sensitivity");
+XlaComputation ClientLibraryTestBase::CreateScalarReluSensitivity() {
+ XlaBuilder builder("relu_sensitivity");
auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {});
auto activation = builder.Parameter(0, shape, "activation");
auto backprop = builder.Parameter(1, shape, "backprop");
return array;
}
-ComputationDataHandle ClientLibraryTestBase::AddParam(
- const Literal& argument, ComputationBuilder* builder) {
- ComputationDataHandle data_handle;
- arguments_.push_back(CreateParameterAndTransferLiteral(
- arguments_.size(), argument, "", builder, &data_handle));
- return data_handle;
-}
-
XlaOp ClientLibraryTestBase::AddParam(const Literal& argument,
XlaBuilder* builder) {
XlaOp data_handle;
return data_handle;
}
-ComputationDataHandle ClientLibraryTestBase::CreateConstantFromLiteral(
- const Literal& literal, ComputationBuilder* builder) {
- return builder->ConstantLiteral(
- use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal);
-}
-
XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal,
XlaBuilder* builder) {
return builder->ConstantLiteral(
use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal);
}
-template void ClientLibraryTestBase::ComputeAndCompareLiteral(
- ComputationBuilder* builder, const Literal& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
- const Shape* shape_with_layout);
-
-template void ClientLibraryTestBase::ComputeAndCompareLiteral(
- XlaBuilder* builder, const Literal& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
- const Shape* shape_with_layout);
-
-template void ClientLibraryTestBase::ComputeAndCompareLiteral(
- ComputationBuilder* builder, const Literal& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error,
- const Shape* shape_with_layout);
-
-template void ClientLibraryTestBase::ComputeAndCompareLiteral(
- XlaBuilder* builder, const Literal& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error,
- const Shape* shape_with_layout);
-
-template void ClientLibraryTestBase::ComputeAndCompareTuple(
- ComputationBuilder* builder, const Literal& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments);
-
-template void ClientLibraryTestBase::ComputeAndCompareTuple(
- XlaBuilder* builder, const Literal& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments);
-
-template void ClientLibraryTestBase::ComputeAndCompareTuple(
- ComputationBuilder* builder, const Literal& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error);
-
-template void ClientLibraryTestBase::ComputeAndCompareTuple(
- XlaBuilder* builder, const Literal& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error);
-
-template StatusOr<std::unique_ptr<GlobalData>> ClientLibraryTestBase::Execute(
- ComputationBuilder* builder,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments);
-
-template StatusOr<std::unique_ptr<GlobalData>> ClientLibraryTestBase::Execute(
- XlaBuilder* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments);
-
} // namespace xla
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/statusor.h"
// Convenience methods for building and running a computation with the member
// execution options. Modify execution_options_ in your test if you want to
// customize the options.
- template <typename BuilderT>
StatusOr<std::unique_ptr<GlobalData>> Execute(
- BuilderT* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+ XlaBuilder* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments);
- // TODO(b/74197823): Remove the template type 'BuilderT' in all methods once
- // the migration to XlaBuilder is complete.
-
- template <typename BuilderT>
StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
- BuilderT* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments,
- const Shape* shape_with_output_layout = nullptr);
-
- StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
- const Computation& computation,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ XlaBuilder* builder, tensorflow::gtl::ArraySlice<GlobalData*> arguments,
const Shape* shape_with_output_layout = nullptr);
StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
tensorflow::gtl::ArraySlice<GlobalData*> arguments,
const Shape* shape_with_output_layout = nullptr);
- // Convenience OrDie variants of above methods.
- std::unique_ptr<GlobalData> ExecuteOrDie(
- ComputationBuilder* builder,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments);
- std::unique_ptr<Literal> ExecuteAndTransferOrDie(
- ComputationBuilder* builder,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments);
-
// Run a computation and return its value as a string. If an error
// occurs, then instead return the error as a string.
string ExecuteToString(XlaBuilder* builder,
tensorflow::gtl::ArraySlice<GlobalData*> arguments);
- string ExecuteToString(ComputationBuilder* builder,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments);
// Convenience methods for building and running a computation, transferring
// the result, and comparing it to the expected value(s). Methods are
// templated on the native host type which maps to specific XLA types (See
- // ComputationBuilder/XlaBuilder for details). For each rank, two forms are
+ // XlaBuilder for details). For each rank, two forms are
// provided: one for floating point types with an ErrorSpec parameter, and one
// for integral types without the ErrorSpec parameter.
- template <typename NativeT, typename BuilderT>
- void ComputeAndCompareR0(BuilderT* builder, NativeT expected,
+ template <typename NativeT>
+ void ComputeAndCompareR0(XlaBuilder* builder, NativeT expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments);
- template <typename NativeT, typename BuilderT>
- void ComputeAndCompareR0(BuilderT* builder, NativeT expected,
+ template <typename NativeT>
+ void ComputeAndCompareR0(XlaBuilder* builder, NativeT expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments,
ErrorSpec error);
- template <typename NativeT, typename BuilderT>
- void ComputeAndCompareR1(BuilderT* builder,
+ template <typename NativeT>
+ void ComputeAndCompareR1(XlaBuilder* builder,
tensorflow::gtl::ArraySlice<NativeT> expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments);
- template <typename NativeT, typename BuilderT>
- void ComputeAndCompareR1(BuilderT* builder,
+ template <typename NativeT>
+ void ComputeAndCompareR1(XlaBuilder* builder,
tensorflow::gtl::ArraySlice<NativeT> expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments,
ErrorSpec error);
// As above, but uses a bitmap to hold the predicate vector to avoid
// deficiencies of vector<bool>.
- void ComputeAndCompareR1(ComputationBuilder* builder,
- const tensorflow::core::Bitmap& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments);
void ComputeAndCompareR1(XlaBuilder* builder,
const tensorflow::core::Bitmap& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments);
- template <typename NativeT, typename BuilderT>
- void ComputeAndCompareR2(BuilderT* builder, const Array2D<NativeT>& expected,
+ template <typename NativeT>
+ void ComputeAndCompareR2(XlaBuilder* builder,
+ const Array2D<NativeT>& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments);
- template <typename NativeT, typename BuilderT>
- void ComputeAndCompareR2(BuilderT* builder, const Array2D<NativeT>& expected,
+ template <typename NativeT>
+ void ComputeAndCompareR2(XlaBuilder* builder,
+ const Array2D<NativeT>& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments,
ErrorSpec error);
- template <typename NativeT, typename BuilderT>
- void ComputeAndCompareR3(BuilderT* builder, const Array3D<NativeT>& expected,
+ template <typename NativeT>
+ void ComputeAndCompareR3(XlaBuilder* builder,
+ const Array3D<NativeT>& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments);
- template <typename NativeT, typename BuilderT>
- void ComputeAndCompareR3(BuilderT* builder, const Array3D<NativeT>& expected,
+ template <typename NativeT>
+ void ComputeAndCompareR3(XlaBuilder* builder,
+ const Array3D<NativeT>& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments,
ErrorSpec error);
- template <typename NativeT, typename BuilderT>
- void ComputeAndCompareR4(BuilderT* builder, const Array4D<NativeT>& expected,
+ template <typename NativeT>
+ void ComputeAndCompareR4(XlaBuilder* builder,
+ const Array4D<NativeT>& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments);
- template <typename NativeT, typename BuilderT>
- void ComputeAndCompareR4(BuilderT* builder, const Array4D<NativeT>& expected,
+ template <typename NativeT>
+ void ComputeAndCompareR4(XlaBuilder* builder,
+ const Array4D<NativeT>& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments,
ErrorSpec error);
// Build and run the computation and compare the result with the given
// literal. shape_with_layout indicates the result layout to request when
// calling Execute.
- template <typename BuilderT>
void ComputeAndCompareLiteral(
- BuilderT* builder, const Literal& expected,
+ XlaBuilder* builder, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments,
const Shape* shape_with_layout = nullptr);
- template <typename BuilderT>
void ComputeAndCompareLiteral(
- BuilderT* builder, const Literal& expected,
+ XlaBuilder* builder, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error,
const Shape* shape_with_layout = nullptr);
// ComputeAndCompare variant which returns an error status.
- template <typename BuilderT>
tensorflow::Status ComputeAndCompareLiteralWithStatus(
- BuilderT* builder, const Literal& expected,
+ XlaBuilder* builder, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments,
const Shape* shape_with_layout = nullptr);
- template <typename BuilderT>
tensorflow::Status ComputeAndCompareLiteralWithStatus(
- BuilderT* builder, const Literal& expected,
+ XlaBuilder* builder, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error,
const Shape* shape_with_layout = nullptr);
// Convenience method for running a built computation, transferring the
// result, and comparing it to the expected tuple literal.
- template <typename BuilderT>
void ComputeAndCompareTuple(
- BuilderT* builder, const Literal& expected,
+ XlaBuilder* builder, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments);
- template <typename BuilderT>
void ComputeAndCompareTuple(
- BuilderT* builder, const Literal& expected,
+ XlaBuilder* builder, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error);
// Convenience method for running a built computation and comparing the result
- // with the HloEvaluator.
- void ComputeAndCompare(ComputationBuilder* builder,
- const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<Literal> arguments);
- void ComputeAndCompare(ComputationBuilder* builder,
- const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<Literal> arguments,
- ErrorSpec error);
-
- // Convenience method for running a built computation and comparing the result
// with the reference result.
void ComputeAndCompare(XlaBuilder* builder,
tensorflow::gtl::ArraySlice<Literal> arguments);
// Create scalar operations for use in reductions.
XlaComputation CreateScalarRelu();
XlaComputation CreateScalarMax();
- Computation CreateScalarReluSensitivity();
+ XlaComputation CreateScalarReluSensitivity();
// Special case convenience functions for creating filled arrays.
// server, then stores into "data_handle" the global handle for that
// parameter. When the use_bfloat16 flag is set but the literal has F32
// elements, the literal will be converted to BF16 before being transferred.
- template <typename BuilderT, typename HandleT>
std::unique_ptr<GlobalData> CreateParameterAndTransferLiteral(
int64 parameter_number, const Literal& literal, const string& name,
- BuilderT* builder, HandleT* data_handle);
+ XlaBuilder* builder, XlaOp* data_handle);
// As above, but the caller can specify the device that the literal is
// transferred to. If device_handle is nullptr, the literal will be
// transferred to the default device.
- template <typename BuilderT, typename HandleT>
std::unique_ptr<GlobalData> CreateParameterAndTransferLiteral(
int64 parameter_number, const Literal& literal, const string& name,
- const DeviceHandle* device_handle, BuilderT* builder,
- HandleT* data_handle);
+ const DeviceHandle* device_handle, XlaBuilder* builder,
+ XlaOp* data_handle);
// Creates a parameter instruction and sets the value that will be passed to
// the computation as specified. This function must be used for all parameters
// or none and no parameters must be passed when invoking the computation if
// using this mechanism. If using this mechanism, then each parameter must be
// set exactly once. The first added parameter gets index 0, then 1 and so on.
- ComputationDataHandle AddParam(const Literal& argument,
- ComputationBuilder* builder);
XlaOp AddParam(const Literal& argument, XlaBuilder* builder);
template <class T>
- ComputationDataHandle AddParam(const Array<T>& argument,
- ComputationBuilder* builder) {
- return AddParam(*Literal::CreateFromArray(argument), builder);
- }
- template <class T>
XlaOp AddParam(const Array<T>& argument, XlaBuilder* builder) {
return AddParam(*Literal::CreateFromArray(argument), builder);
}
// Creates a constant instruction with the given literal. When the
// use_bfloat16 flag is set but the literal has F32 elements, the elements
// will be converted to BF16s.
- ComputationDataHandle CreateConstantFromLiteral(const Literal& literal,
- ComputationBuilder* builder);
XlaOp CreateConstantFromLiteral(const Literal& literal, XlaBuilder* builder);
// Creates a constant instruction with the given array. When the use_bfloat16
// flag is set but the array has float elements, the elements will be
// converted to bfloat16s.
- template <typename NativeT>
- ComputationDataHandle CreateConstantFromArray(const Array<NativeT>& array,
- ComputationBuilder* builder) {
- return CreateConstantFromLiteral(*Literal::CreateFromArray(array), builder);
- }
template <typename NativeT>
XlaOp CreateConstantFromArray(const Array<NativeT>& array,
// Same as CreateConstantFromArray, but for scalars.
template <typename NativeT>
- ComputationDataHandle CreateConstantFromScalar(NativeT value,
- ComputationBuilder* builder) {
- return CreateConstantFromLiteral(*Literal::CreateR0<NativeT>(value),
- builder);
- }
-
- template <typename NativeT>
XlaOp CreateConstantFromScalar(NativeT value, XlaBuilder* builder) {
return CreateConstantFromLiteral(*Literal::CreateR0<NativeT>(value),
builder);
//
// When the use_bfloat16 flag is set but NativeT is float, the data will be
// converted to bfloat16.
- template <typename NativeT, typename BuilderT, typename HandleT>
+ template <typename NativeT>
std::unique_ptr<GlobalData> CreateR0Parameter(NativeT value,
int64 parameter_number,
const string& name,
- BuilderT* builder,
- HandleT* data_handle);
+ XlaBuilder* builder,
+ XlaOp* data_handle);
// Creates a parameter instruction that wraps the given values and then stores
// into "data_handle" the global handle for that parameter.
//
// When the use_bfloat16 flag is set but NativeT is float, the data will be
// converted to bfloat16.
- template <typename NativeT, typename BuilderT, typename HandleT>
+ template <typename NativeT>
std::unique_ptr<GlobalData> CreateR1Parameter(
tensorflow::gtl::ArraySlice<NativeT> values, int64 parameter_number,
- const string& name, BuilderT* builder, HandleT* data_handle);
+ const string& name, XlaBuilder* builder, XlaOp* data_handle);
// Creates a parameter instruction that wraps the given constant array
// "array_2d" and then stores to "data_handle" the global handle for that
//
// When the use_bfloat16 flag is set but NativeT is float, the data will be
// converted to bfloat16.
- template <typename NativeT, typename BuilderT, typename HandleT>
+ template <typename NativeT>
std::unique_ptr<GlobalData> CreateR2Parameter(
const Array2D<NativeT>& array_2d, int64 parameter_number,
- const string& name, BuilderT* builder, HandleT* data_handle);
+ const string& name, XlaBuilder* builder, XlaOp* data_handle);
// Creates a parameter instruction that wraps the given constant array
// "array_3d" and then stores to "data_handle" the global handle for that
//
// When the use_bfloat16 flag is set but NativeT is float, the data will be
// converted to bfloat16.
- template <typename NativeT, typename BuilderT, typename HandleT>
+ template <typename NativeT>
std::unique_ptr<GlobalData> CreateR3Parameter(
const Array3D<NativeT>& array_3d, int64 parameter_number,
- const string& name, BuilderT* builder, HandleT* data_handle);
+ const string& name, XlaBuilder* builder, XlaOp* data_handle);
// Getter and setter for the use_bfloat16 flag, which indicates whether to run
// tests with all float-type input/output converted to bfloat16.
ExecutionOptions execution_options_;
private:
- // Build and run the computation with all permutations of output layouts.
- tensorflow::Status ComputeAndCompareLiteralWithAllOutputLayouts(
- const xla::Computation& computation, const Literal& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
- const std::function<void(const Literal& actual,
- const string& error_message)>& verify_output);
- // Build and run the computation with all permutations of layouts of all input
- // arguments.
- tensorflow::Status ComputeAndCompareLiteralWithAllInputLayouts(
- const xla::Computation& computation, const Literal& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments,
- const std::function<void(const Literal& actual,
- const string& error_message)>& verify_output,
- const Shape* output_with_layout = nullptr);
-
tensorflow::Status ComputeAndCompareLiteralWithAllOutputLayouts(
const xla::XlaComputation& computation, const Literal& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments,
const Shape* output_with_layout = nullptr);
// Executes the computation and calculates the expected reference value using
- // the HloEvaluator. Returns two literals in the order of (expected, actual).
- StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
- ComputeValueAndReference(ComputationBuilder* builder,
- const ComputationDataHandle& operand,
- tensorflow::gtl::ArraySlice<Literal> arguments);
-
- // Executes the computation and calculates the expected reference value using
// the reference client. Returns two literals in the order of (expected,
// actual).
StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
std::vector<std::unique_ptr<GlobalData>> arguments_;
};
-template <typename NativeT, typename BuilderT>
+template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR0(
- BuilderT* builder, NativeT expected,
+ XlaBuilder* builder, NativeT expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
std::unique_ptr<Literal> expected_literal =
Literal::CreateR0<NativeT>(expected);
arguments);
}
-template <typename NativeT, typename BuilderT>
+template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR0(
- BuilderT* builder, NativeT expected,
+ XlaBuilder* builder, NativeT expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
static_assert(std::is_same<NativeT, float>::value ||
std::is_same<NativeT, double>::value ||
arguments, error);
}
-template <typename NativeT, typename BuilderT>
+template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR1(
- BuilderT* builder, tensorflow::gtl::ArraySlice<NativeT> expected,
+ XlaBuilder* builder, tensorflow::gtl::ArraySlice<NativeT> expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
std::unique_ptr<Literal> expected_literal =
Literal::CreateR1<NativeT>(expected);
arguments);
}
-template <typename NativeT, typename BuilderT>
+template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR1(
- BuilderT* builder, tensorflow::gtl::ArraySlice<NativeT> expected,
+ XlaBuilder* builder, tensorflow::gtl::ArraySlice<NativeT> expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
static_assert(std::is_same<NativeT, float>::value ||
std::is_same<NativeT, double>::value ||
arguments, error);
}
-template <typename NativeT, typename BuilderT>
+template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR2(
- BuilderT* builder, const Array2D<NativeT>& expected,
+ XlaBuilder* builder, const Array2D<NativeT>& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
std::unique_ptr<Literal> expected_literal =
Literal::CreateR2FromArray2D<NativeT>(expected);
arguments);
}
-template <typename NativeT, typename BuilderT>
+template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR2(
- BuilderT* builder, const Array2D<NativeT>& expected,
+ XlaBuilder* builder, const Array2D<NativeT>& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
static_assert(std::is_same<NativeT, float>::value ||
std::is_same<NativeT, double>::value ||
arguments, error);
}
-template <typename NativeT, typename BuilderT>
+template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR3(
- BuilderT* builder, const Array3D<NativeT>& expected,
+ XlaBuilder* builder, const Array3D<NativeT>& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
std::unique_ptr<Literal> expected_literal =
Literal::CreateR3FromArray3D<NativeT>(expected);
arguments);
}
-template <typename NativeT, typename BuilderT>
+template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR3(
- BuilderT* builder, const Array3D<NativeT>& expected,
+ XlaBuilder* builder, const Array3D<NativeT>& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
static_assert(std::is_same<NativeT, float>::value ||
std::is_same<NativeT, double>::value ||
arguments, error);
}
-template <typename NativeT, typename BuilderT>
+template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR4(
- BuilderT* builder, const Array4D<NativeT>& expected,
+ XlaBuilder* builder, const Array4D<NativeT>& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
std::unique_ptr<Literal> expected_literal =
Literal::CreateR4FromArray4D<NativeT>(expected);
arguments);
}
-template <typename NativeT, typename BuilderT>
+template <typename NativeT>
void ClientLibraryTestBase::ComputeAndCompareR4(
- BuilderT* builder, const Array4D<NativeT>& expected,
+ XlaBuilder* builder, const Array4D<NativeT>& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
static_assert(std::is_same<NativeT, float>::value ||
std::is_same<NativeT, double>::value ||
arguments, error);
}
-template <typename NativeT, typename BuilderT, typename HandleT>
+template <typename NativeT>
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR0Parameter(
NativeT value, int64 parameter_number, const string& name,
- BuilderT* builder, HandleT* data_handle) {
+ XlaBuilder* builder, XlaOp* data_handle) {
std::unique_ptr<Literal> literal = Literal::CreateR0(value);
if (use_bfloat16_ && literal->shape().element_type() == F32) {
literal = LiteralTestUtil::ConvertF32ToBF16(*literal);
return data;
}
-template <typename NativeT, typename BuilderT, typename HandleT>
+template <typename NativeT>
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR1Parameter(
tensorflow::gtl::ArraySlice<NativeT> values, int64 parameter_number,
- const string& name, BuilderT* builder, HandleT* data_handle) {
+ const string& name, XlaBuilder* builder, XlaOp* data_handle) {
std::unique_ptr<Literal> literal = Literal::CreateR1(values);
if (use_bfloat16_ && literal->shape().element_type() == F32) {
literal = LiteralTestUtil::ConvertF32ToBF16(*literal);
return data;
}
-template <typename NativeT, typename BuilderT, typename HandleT>
+template <typename NativeT>
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR2Parameter(
const Array2D<NativeT>& array_2d, int64 parameter_number,
- const string& name, BuilderT* builder, HandleT* data_handle) {
+ const string& name, XlaBuilder* builder, XlaOp* data_handle) {
std::unique_ptr<Literal> literal = Literal::CreateR2FromArray2D(array_2d);
if (use_bfloat16_ && literal->shape().element_type() == F32) {
literal = LiteralTestUtil::ConvertF32ToBF16(*literal);
return data;
}
-template <typename NativeT, typename BuilderT, typename HandleT>
+template <typename NativeT>
std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR3Parameter(
const Array3D<NativeT>& array_3d, int64 parameter_number,
- const string& name, BuilderT* builder, HandleT* data_handle) {
+ const string& name, XlaBuilder* builder, XlaOp* data_handle) {
std::unique_ptr<Literal> literal = Literal::CreateR3FromArray3D(array_3d);
if (use_bfloat16_ && literal->shape().element_type() == F32) {
literal = LiteralTestUtil::ConvertF32ToBF16(*literal);
return result;
}
-template <typename BuilderT, typename HandleT>
std::unique_ptr<GlobalData>
ClientLibraryTestBase::CreateParameterAndTransferLiteral(int64 parameter_number,
const Literal& literal,
const string& name,
- BuilderT* builder,
- HandleT* data_handle) {
+ XlaBuilder* builder,
+ XlaOp* data_handle) {
return CreateParameterAndTransferLiteral(parameter_number, literal, name,
nullptr, builder, data_handle);
}
-template <typename BuilderT, typename HandleT>
std::unique_ptr<GlobalData>
ClientLibraryTestBase::CreateParameterAndTransferLiteral(
int64 parameter_number, const Literal& literal, const string& name,
- const DeviceHandle* device_handle, BuilderT* builder,
- HandleT* data_handle) {
+ const DeviceHandle* device_handle, XlaBuilder* builder,
+ XlaOp* data_handle) {
const Literal* param_literal = &literal;
std::unique_ptr<Literal> converted_literal;
if (use_bfloat16_) {