use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal);
}
+std::unique_ptr<GlobalData>
+ClientLibraryTestBase::CreateParameterAndTransferLiteral(int64 parameter_number,
+ const Literal& literal,
+ const string& name,
+ XlaBuilder* builder,
+ XlaOp* data_handle) {
+ return CreateParameterAndTransferLiteral(parameter_number, literal, name,
+ nullptr, builder, data_handle);
+}
+
+std::unique_ptr<GlobalData>
+ClientLibraryTestBase::CreateParameterAndTransferLiteral(
+ int64 parameter_number, const Literal& literal, const string& name,
+ const DeviceHandle* device_handle, XlaBuilder* builder,
+ XlaOp* data_handle) {
+ const Literal* param_literal = &literal;
+ std::unique_ptr<Literal> converted_literal;
+ if (use_bfloat16_) {
+ converted_literal = LiteralTestUtil::ConvertF32ToBF16(literal);
+ param_literal = converted_literal.get();
+ }
+ std::unique_ptr<GlobalData> data =
+ client_->TransferToServer(*param_literal, device_handle)
+ .ConsumeValueOrDie();
+ *data_handle =
+ builder->Parameter(parameter_number, param_literal->shape(), name);
+ return data;
+}
+
} // namespace xla
return result;
}
-std::unique_ptr<GlobalData>
-ClientLibraryTestBase::CreateParameterAndTransferLiteral(int64 parameter_number,
- const Literal& literal,
- const string& name,
- XlaBuilder* builder,
- XlaOp* data_handle) {
- return CreateParameterAndTransferLiteral(parameter_number, literal, name,
- nullptr, builder, data_handle);
-}
-
-std::unique_ptr<GlobalData>
-ClientLibraryTestBase::CreateParameterAndTransferLiteral(
- int64 parameter_number, const Literal& literal, const string& name,
- const DeviceHandle* device_handle, XlaBuilder* builder,
- XlaOp* data_handle) {
- const Literal* param_literal = &literal;
- std::unique_ptr<Literal> converted_literal;
- if (use_bfloat16_) {
- converted_literal = LiteralTestUtil::ConvertF32ToBF16(literal);
- param_literal = converted_literal.get();
- }
- std::unique_ptr<GlobalData> data =
- client_->TransferToServer(*param_literal, device_handle)
- .ConsumeValueOrDie();
- *data_handle =
- builder->Parameter(parameter_number, param_literal->shape(), name);
- return data;
-}
-
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_