[XLA] Cleanup client_library_test_base: move definition of CreateParameterAndTransfer...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 4 May 2018 18:40:01 +0000 (11:40 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 4 May 2018 20:34:49 +0000 (13:34 -0700)
PiperOrigin-RevId: 195446864

tensorflow/compiler/xla/tests/client_library_test_base.cc
tensorflow/compiler/xla/tests/client_library_test_base.h

index c09e7ea..41f9a5f 100644 (file)
@@ -565,4 +565,33 @@ XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal,
       use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal);
 }
 
+std::unique_ptr<GlobalData>
+ClientLibraryTestBase::CreateParameterAndTransferLiteral(int64 parameter_number,
+                                                         const Literal& literal,
+                                                         const string& name,
+                                                         XlaBuilder* builder,
+                                                         XlaOp* data_handle) {
+  return CreateParameterAndTransferLiteral(parameter_number, literal, name,
+                                           nullptr, builder, data_handle);
+}
+
+std::unique_ptr<GlobalData>
+ClientLibraryTestBase::CreateParameterAndTransferLiteral(
+    int64 parameter_number, const Literal& literal, const string& name,
+    const DeviceHandle* device_handle, XlaBuilder* builder,
+    XlaOp* data_handle) {
+  const Literal* param_literal = &literal;
+  std::unique_ptr<Literal> converted_literal;
+  if (use_bfloat16_) {
+    converted_literal = LiteralTestUtil::ConvertF32ToBF16(literal);
+    param_literal = converted_literal.get();
+  }
+  std::unique_ptr<GlobalData> data =
+      client_->TransferToServer(*param_literal, device_handle)
+          .ConsumeValueOrDie();
+  *data_handle =
+      builder->Parameter(parameter_number, param_literal->shape(), name);
+  return data;
+}
+
 }  // namespace xla
index e58979a..16e838e 100644 (file)
@@ -616,35 +616,6 @@ std::unique_ptr<Array2D<NativeT>> ClientLibraryTestBase::CreatePseudorandomR2(
   return result;
 }
 
-std::unique_ptr<GlobalData>
-ClientLibraryTestBase::CreateParameterAndTransferLiteral(int64 parameter_number,
-                                                         const Literal& literal,
-                                                         const string& name,
-                                                         XlaBuilder* builder,
-                                                         XlaOp* data_handle) {
-  return CreateParameterAndTransferLiteral(parameter_number, literal, name,
-                                           nullptr, builder, data_handle);
-}
-
-std::unique_ptr<GlobalData>
-ClientLibraryTestBase::CreateParameterAndTransferLiteral(
-    int64 parameter_number, const Literal& literal, const string& name,
-    const DeviceHandle* device_handle, XlaBuilder* builder,
-    XlaOp* data_handle) {
-  const Literal* param_literal = &literal;
-  std::unique_ptr<Literal> converted_literal;
-  if (use_bfloat16_) {
-    converted_literal = LiteralTestUtil::ConvertF32ToBF16(literal);
-    param_literal = converted_literal.get();
-  }
-  std::unique_ptr<GlobalData> data =
-      client_->TransferToServer(*param_literal, device_handle)
-          .ConsumeValueOrDie();
-  *data_handle =
-      builder->Parameter(parameter_number, param_literal->shape(), name);
-  return data;
-}
-
 }  // namespace xla
 
 #endif  // TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_