[XLA] Break out literal comparisons from testonly target.
authorChris Leary <leary@google.com>
Fri, 11 May 2018 03:10:34 +0000 (20:10 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 11 May 2018 03:13:33 +0000 (20:13 -0700)
Moves methods from LiteralTestUtil::* to Literal::* where they have nothing
to do with test infrastructure.

Pares down the "void" variants of the LiteralTestUtil methods and consolidates
to the version that return success/failure such that the values can be
EXPECT_TRUE / ASSERT_TRUE asserted in the caller test cases.

This way the literal comparison functionality can be used from cc_libraries
that are not test only / cc_binary.

PiperOrigin-RevId: 196209410

32 files changed:
tensorflow/compiler/tf2xla/xla_compiler_test.cc
tensorflow/compiler/xla/BUILD
tensorflow/compiler/xla/literal_comparison.cc [new file with mode: 0644]
tensorflow/compiler/xla/literal_comparison.h [new file with mode: 0644]
tensorflow/compiler/xla/literal_util.cc
tensorflow/compiler/xla/literal_util.h
tensorflow/compiler/xla/rpc/grpc_client_test.cc
tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
tensorflow/compiler/xla/service/hlo_constant_folding_test.cc
tensorflow/compiler/xla/service/hlo_cse_test.cc
tensorflow/compiler/xla/service/hlo_evaluator_test.cc
tensorflow/compiler/xla/service/inliner_test.cc
tensorflow/compiler/xla/tests/BUILD
tensorflow/compiler/xla/tests/broadcast_test.cc
tensorflow/compiler/xla/tests/client_library_test_base.cc
tensorflow/compiler/xla/tests/client_library_test_base.h
tensorflow/compiler/xla/tests/client_test.cc
tensorflow/compiler/xla/tests/compilation_cache_test.cc
tensorflow/compiler/xla/tests/compute_constant_test.cc
tensorflow/compiler/xla/tests/copy_test.cc
tensorflow/compiler/xla/tests/fusion_test.cc
tensorflow/compiler/xla/tests/gather_operation_test.cc
tensorflow/compiler/xla/tests/literal_test_util.cc
tensorflow/compiler/xla/tests/literal_test_util.h
tensorflow/compiler/xla/tests/literal_test_util_test.cc
tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
tensorflow/compiler/xla/tests/prng_test.cc
tensorflow/compiler/xla/tests/reshape_test.cc
tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc
tensorflow/compiler/xla/tests/round_trip_transfer_test.cc
tensorflow/compiler/xla/tests/scalar_computations_test.cc
tensorflow/compiler/xla/tests/transfer_manager_test.cc

index 6b8918b..4382ffe 100644 (file)
@@ -225,7 +225,7 @@ TEST_F(XlaCompilerTest, Simple) {
       xla::Literal::CreateR1<int32>({4, 143});
   std::unique_ptr<xla::Literal> expected_literal =
       xla::Literal::MakeTuple({expected0.get()});
-  xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
+  EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
 }
 
 TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) {
@@ -320,7 +320,8 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
         xla::Literal::CreateR1<int32>({-7, -42});
     std::unique_ptr<xla::Literal> expected_literal =
         xla::Literal::MakeTuple({expected0.get()});
-    xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
+    EXPECT_TRUE(
+        xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
   }
 
   {
@@ -355,7 +356,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
         xla::Literal::CreateR1<int32>({-7, -42});
     std::unique_ptr<xla::Literal> expected =
         xla::Literal::MakeTuple({expected0.get(), expected1.get()});
-    xla::LiteralTestUtil::ExpectEqual(*expected, *actual_literal);
+    EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected, *actual_literal));
   }
 }
 
@@ -523,7 +524,7 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) {
       {output_base.get(), output_grad1.get(), output_grad2.get()});
   std::unique_ptr<xla::Literal> expected_literal =
       xla::Literal::MakeTuple({output_read.get(), output_resource.get()});
-  xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
+  EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
 }
 
 // Tests compilation and execution of a graph that adds two tensors.
@@ -746,7 +747,7 @@ TEST_F(XlaCompilerTest, Variables) {
       xla::Literal::CreateR1<int32>({4, 143});
   std::unique_ptr<xla::Literal> expected_literal =
       xla::Literal::MakeTuple({expected0.get(), expected1.get()});
-  xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
+  EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
 }
 
 // Tests a simple graph that reads and writes a variable, with a
@@ -811,7 +812,7 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
       xla::Literal::CreateR1<int32>({26, 66, 34, 401});
   std::unique_ptr<xla::Literal> expected_literal =
       xla::Literal::MakeTuple({expected0.get(), expected1.get()});
-  xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
+  EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
 }
 
 }  // namespace
index dbf14f3..729480e 100644 (file)
@@ -331,6 +331,17 @@ tf_cc_test(
 )
 
 cc_library(
+    name = "literal_comparison",
+    srcs = ["literal_comparison.cc"],
+    hdrs = ["literal_comparison.h"],
+    deps = [
+        ":literal_util",
+        ":util",
+        "//tensorflow/core:lib",
+    ],
+)
+
+cc_library(
     name = "metric_table_report",
     srcs = ["metric_table_report.cc"],
     hdrs = ["metric_table_report.h"],
diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc
new file mode 100644 (file)
index 0000000..df3f5af
--- /dev/null
@@ -0,0 +1,226 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/literal_comparison.h"
+
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/casts.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+using tensorflow::strings::StrCat;
+
+namespace xla {
+namespace literal_comparison {
+namespace {
+
+// Helper function for comparing a floating point type, FloatT, bitwise equal
+// between the left-hand-side and right-hand-side, by bit-casting to UnsignedT
+// -- on miscompare, a nice error message is given in the AssertionFailure.
+template <typename FloatT, typename UnsignedT>
+Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) {
+  auto ulhs = tensorflow::bit_cast<UnsignedT>(lhs);
+  auto urhs = tensorflow::bit_cast<UnsignedT>(rhs);
+  auto lhs_double = static_cast<double>(lhs);
+  auto rhs_double = static_cast<double>(rhs);
+  if (ulhs != urhs) {
+    return InvalidArgument(
+        "floating values are not bitwise-equal; and equality testing "
+        "was requested: %s=%g=%a vs %s=%g=%a",
+        StrCat(tensorflow::strings::Hex(ulhs)).c_str(), lhs_double, lhs_double,
+        StrCat(tensorflow::strings::Hex(urhs)).c_str(), rhs_double, rhs_double);
+  }
+  return Status::OK();
+}
+
+// Templated comparator that specializes for float equality comparison with the
+// bitwise helper above (this is the un-specialized fallback, to just use the
+// default gunit implementation).
+template <typename NativeT>
+Status CompareEqual(NativeT lhs, NativeT rhs) {
+  if (lhs == rhs) {
+    return Status::OK();
+  }
+  return InvalidArgument("Expected equality of these values:\n  %s\n  %s",
+                         StrCat(lhs).c_str(), StrCat(rhs).c_str());
+}
+
+// Specializations for floating types that do bitwise comparisons when equality
+// comparison is requested.
+template <>
+Status CompareEqual<bfloat16>(bfloat16 lhs, bfloat16 rhs) {
+  return CompareFloatsBitwiseEqual<bfloat16, uint16>(lhs, rhs);
+}
+template <>
+Status CompareEqual<Eigen::half>(Eigen::half lhs, Eigen::half rhs) {
+  return CompareFloatsBitwiseEqual<Eigen::half, uint16>(lhs, rhs);
+}
+template <>
+Status CompareEqual<float>(float lhs, float rhs) {
+  return CompareFloatsBitwiseEqual<float, uint32>(lhs, rhs);
+}
+template <>
+Status CompareEqual<double>(double lhs, double rhs) {
+  return CompareFloatsBitwiseEqual<double, uint64>(lhs, rhs);
+}
+template <>
+Status CompareEqual<complex64>(complex64 lhs, complex64 rhs) {
+  auto res = CompareEqual<float>(lhs.real(), rhs.real());
+  if (!res.ok()) {
+    return res;
+  }
+  return CompareEqual<float>(lhs.imag(), rhs.imag());
+}
+
+// A recursive function which iterates through every index of expected and
+// actual literal and compares their values elementwise. Returns true if all
+// elements are equal.
+template <typename NativeT>
+Status Equal(LiteralSlice expected, LiteralSlice actual,
+             tensorflow::gtl::MutableArraySlice<int64> multi_index,
+             int64 dimension) {
+  if (dimension == expected.shape().dimensions_size()) {
+    NativeT expected_value = expected.Get<NativeT>(multi_index);
+    NativeT actual_value = actual.Get<NativeT>(multi_index);
+    return CompareEqual<NativeT>(expected_value, actual_value);
+  }
+
+  Status result;
+  for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) {
+    multi_index[dimension] = i;
+    result.Update(Equal<NativeT>(expected, actual, multi_index, dimension + 1));
+  }
+  return result;
+}
+
+}  // namespace
+
+Status EqualShapes(const Shape& expected, const Shape& actual) {
+  if (ShapeUtil::IsTuple(expected) != ShapeUtil::IsTuple(actual)) {
+    return InvalidArgument("tupleness-mismatch! want: %s got %s",
+                           ShapeUtil::HumanString(expected).c_str(),
+                           ShapeUtil::HumanString(actual).c_str());
+  }
+  if (ShapeUtil::IsTuple(expected)) {
+    if (ShapeUtil::TupleElementCount(expected) !=
+        ShapeUtil::TupleElementCount(actual)) {
+      return InvalidArgument(
+          "want tuple element count: %lld got tuple element count: %lld",
+          ShapeUtil::TupleElementCount(expected),
+          ShapeUtil::TupleElementCount(actual));
+    }
+    for (int i = 0; i < expected.tuple_shapes_size(); ++i) {
+      Status result =
+          EqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i));
+      if (!result.ok()) {
+        return AppendStatus(result, StrCat("mismatch in tuple index", i));
+      }
+    }
+  } else {
+    if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) {
+      return InvalidArgument("want rank of %s got rank of %s",
+                             ShapeUtil::HumanString(expected).c_str(),
+                             ShapeUtil::HumanString(actual).c_str());
+    }
+    if (expected.element_type() != actual.element_type()) {
+      return InvalidArgument(
+          "mismatch in primitive type %s vs %s",
+          PrimitiveType_Name(expected.element_type()).c_str(),
+          PrimitiveType_Name(actual.element_type()).c_str());
+    }
+    if (expected.dimensions_size() != actual.dimensions_size()) {
+      return InvalidArgument("want dimensions_size %d got dimensions_size %d",
+                             expected.dimensions_size(),
+                             actual.dimensions_size());
+    }
+    for (int i = 0; i < expected.dimensions_size(); ++i) {
+      if (expected.dimensions(i) != actual.dimensions(i)) {
+        return InvalidArgument(
+            "mismatch in dimension #%d expected: %s actual: %s", i,
+            ShapeUtil::HumanString(expected).c_str(),
+            ShapeUtil::HumanString(actual).c_str());
+      }
+    }
+  }
+  return Status::OK();
+}
+
+Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) {
+  VLOG(1) << "expected:";
+  XLA_VLOG_LINES(1, expected.ToString());
+  VLOG(1) << "actual:";
+  XLA_VLOG_LINES(1, actual.ToString());
+
+  TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape()));
+  std::vector<int64> multi_index(expected.shape().dimensions_size(), 0);
+  Status result;
+  switch (expected.shape().element_type()) {
+    case PRED:
+      result = Equal<bool>(expected, actual, &multi_index, 0);
+      break;
+    case U8:
+      result = Equal<uint8>(expected, actual, &multi_index, 0);
+      break;
+    case S32:
+      result = Equal<int32>(expected, actual, &multi_index, 0);
+      break;
+    case S64:
+      result = Equal<int64>(expected, actual, &multi_index, 0);
+      break;
+    case U32:
+      result = Equal<uint32>(expected, actual, &multi_index, 0);
+      break;
+    case U64:
+      result = Equal<uint64>(expected, actual, &multi_index, 0);
+      break;
+    case BF16:
+      result = Equal<bfloat16>(expected, actual, &multi_index, 0);
+      break;
+    case F16:
+      result = Equal<half>(expected, actual, &multi_index, 0);
+      break;
+    case F32:
+      result = Equal<float>(expected, actual, &multi_index, 0);
+      break;
+    case F64:
+      result = Equal<double>(expected, actual, &multi_index, 0);
+      break;
+    case C64:
+      result = Equal<complex64>(expected, actual, &multi_index, 0);
+      break;
+    case TUPLE: {
+      for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
+        result.Update(
+            Equal(LiteralSlice(expected, {i}), LiteralSlice(actual, {i})));
+      }
+      break;
+    }
+    default:
+      LOG(FATAL)
+          << "Unsupported primitive type in LiteralTestUtil::ExpectEqual: "
+          << PrimitiveType_Name(expected.shape().element_type());
+  }
+
+  if (result.ok()) {
+    return Status::OK();
+  }
+
+  return AppendStatus(result,
+                      tensorflow::strings::Printf("expected: %s\nactual:   %s",
+                                                  expected.ToString().c_str(),
+                                                  actual.ToString().c_str()));
+}
+
+}  // namespace literal_comparison
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/literal_comparison.h b/tensorflow/compiler/xla/literal_comparison.h
new file mode 100644 (file)
index 0000000..e667405
--- /dev/null
@@ -0,0 +1,40 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Library for comparing literals without taking a dependency on testing
+// libraries.
+
+#ifndef TENSORFLOW_COMPILER_XLA_LITERAL_COMPARISON_H_
+#define TENSORFLOW_COMPILER_XLA_LITERAL_COMPARISON_H_
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace xla {
+namespace literal_comparison {
+
+// Returns ok if the given shapes have the same rank, dimension sizes, and
+// primitive types.
+Status EqualShapes(const Shape& expected, const Shape& actual);
+
+// Returns ok if the expected and actual literals are (bitwise) equal for all
+// elements in the literal. Also, asserts that the rank, dimensions sizes, and
+// primitive type are equal.
+Status Equal(const LiteralSlice& expected, const LiteralSlice& actual);
+
+}  // namespace literal_comparison
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_LITERAL_COMPARISON_H_
index e9b0e11..82a2bca 100644 (file)
@@ -62,6 +62,45 @@ void ConvertEndianShort(char* bytes, int64 size) {
   }
 }
 
+// Return a literal with all arrays of type FromNativeT converted to type
+// ToNativeT in the given literal.
+template <typename FromNativeT, typename ToNativeT>
+std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
+  // First construct shape of the result.
+  Shape result_shape(literal.shape());
+  ShapeUtil::ForEachMutableSubshape(
+      &result_shape, [](Shape* subshape, const ShapeIndex&) {
+        if (subshape->element_type() ==
+            primitive_util::NativeToPrimitiveType<FromNativeT>()) {
+          subshape->set_element_type(
+              primitive_util::NativeToPrimitiveType<ToNativeT>());
+        }
+      });
+  auto result = MakeUnique<Literal>(result_shape);
+
+  // Then copy over the data from 'literal' converting FromNativeT values to
+  // ToNativeT values as necessary.
+  ShapeUtil::ForEachSubshape(
+      literal.shape(),
+      [&](const Shape& subshape, const ShapeIndex& shape_index) {
+        if (ShapeUtil::IsArray(subshape)) {
+          if (subshape.element_type() ==
+              primitive_util::NativeToPrimitiveType<FromNativeT>()) {
+            auto src = literal.data<FromNativeT>(shape_index);
+            auto dest = result->data<ToNativeT>(shape_index);
+            for (int64 i = 0; i < src.size(); ++i) {
+              dest[i] = static_cast<ToNativeT>(src[i]);
+            }
+          } else {
+            TF_CHECK_OK(result->CopyFrom(literal,
+                                         /*dest_shape_index=*/shape_index,
+                                         /*src_shape_index=*/shape_index));
+          }
+        }
+      });
+  return result;
+}
+
 }  // namespace
 
 LiteralBase::~LiteralBase() {}
@@ -195,6 +234,16 @@ SparseIndexArray* Literal::sparse_indices(const ShapeIndex& shape_index) {
   return CreateFromShape(ShapeUtil::MakeShape(primitive_type, dimensions));
 }
 
+/* static */ std::unique_ptr<Literal> Literal::ConvertBF16ToF32(
+    const LiteralSlice& bf16_literal) {
+  return ConvertType<bfloat16, float>(bf16_literal);
+}
+
+/* static */ std::unique_ptr<Literal> Literal::ConvertF32ToBF16(
+    const LiteralSlice& f32_literal) {
+  return ConvertType<float, bfloat16>(f32_literal);
+}
+
 template <typename NativeT>
 Status Literal::CopySliceFromInternal(
     const LiteralBase& src_literal, tensorflow::gtl::ArraySlice<int64> src_base,
@@ -788,6 +837,78 @@ StatusOr<std::unique_ptr<Literal>> LiteralBase::Reshape(
   return std::move(output);
 }
 
+/* static */ std::unique_ptr<Literal> Literal::ReshapeSlice(
+    tensorflow::gtl::ArraySlice<int64> new_dimensions,
+    tensorflow::gtl::ArraySlice<int64> minor_to_major,
+    const LiteralSlice& literal) {
+  int64 new_num_elements = 1;
+  for (int64 i = 0; i < new_dimensions.size(); ++i) {
+    new_num_elements *= new_dimensions[i];
+  }
+  CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements);
+  CHECK_EQ(new_dimensions.size(), minor_to_major.size());
+
+  auto new_literal = MakeUnique<Literal>(
+      ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions));
+
+  // Create a new shape with the given minor-to-major layout. This shape is used
+  // solely for converting linear address to multi-dimensional addresses when
+  // writing elements to the new literal.
+  Shape shape_with_layout = new_literal->shape();
+  *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
+
+  // Copy data into new literal, element-by-element.
+  for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) {
+    std::vector<int64> from_multi_index =
+        IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i);
+    std::vector<int64> to_multi_index =
+        IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i);
+    switch (literal.shape().element_type()) {
+      case PRED:
+        new_literal->Set<bool>(to_multi_index,
+                               literal.Get<bool>(from_multi_index));
+        break;
+      case U8:
+        new_literal->Set<uint8>(to_multi_index,
+                                literal.Get<uint8>(from_multi_index));
+        break;
+      case U32:
+        new_literal->Set<uint32>(to_multi_index,
+                                 literal.Get<uint32>(from_multi_index));
+        break;
+      case S32:
+        new_literal->Set<int32>(to_multi_index,
+                                literal.Get<int32>(from_multi_index));
+        break;
+      case U64:
+        new_literal->Set<uint64>(to_multi_index,
+                                 literal.Get<uint64>(from_multi_index));
+        break;
+      case S64:
+        new_literal->Set<int64>(to_multi_index,
+                                literal.Get<int64>(from_multi_index));
+        break;
+      case F32:
+        new_literal->Set<float>(to_multi_index,
+                                literal.Get<float>(from_multi_index));
+        break;
+      case F64:
+        new_literal->Set<double>(to_multi_index,
+                                 literal.Get<double>(from_multi_index));
+        break;
+      case C64:
+        new_literal->Set<complex64>(to_multi_index,
+                                    literal.Get<complex64>(from_multi_index));
+        break;
+      default:
+        LOG(FATAL) << "Unhandled primitive element type: "
+                   << PrimitiveType_Name(literal.shape().element_type());
+    }
+  }
+
+  return new_literal;
+}
+
 std::unique_ptr<Literal> LiteralBase::Transpose(
     tensorflow::gtl::ArraySlice<int64> permutation) const {
   CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose";
@@ -2123,6 +2244,11 @@ StatusOr<std::unique_ptr<Literal>> Literal::CreateFromProto(
   return std::move(literal);
 }
 
+/* static */ string Literal::MultiIndexAsString(
+    tensorflow::gtl::ArraySlice<int64> multi_index) {
+  return StrCat("{", tensorflow::str_util::Join(multi_index, ","), "}");
+}
+
 const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const {
   return piece(shape_index).untyped_data();
 }
index 30442af..8d51aa3 100644 (file)
@@ -920,9 +920,66 @@ class Literal : public LiteralBase {
       PrimitiveType primitive_type,
       tensorflow::gtl::ArraySlice<int64> dimensions);
 
+  // If the given literal's data type is bfloat16, converts it to a float
+  // literal; otherwise, returns a copy of it. If the literal is a tuple,
+  // recursively converts its elements.
+  static std::unique_ptr<Literal> ConvertBF16ToF32(
+      const LiteralSlice& bf16_literal);
+
+  // If the given literal's data type is float, converts it to a bfloat16
+  // literal; otherwise, returns a copy of it. If the literal is a tuple,
+  // recursively converts its elements.
+  static std::unique_ptr<Literal> ConvertF32ToBF16(
+      const LiteralSlice& f32_literal);
+
+  // Creates a literal with a new shape with the given new dimensions using the
+  // data in the given input literal. For reshaping purposes the (flat) data
+  // buffer of the input literal is assumed to have the given minor_to_major
+  // layout order.
+  static std::unique_ptr<Literal> ReshapeSlice(
+      tensorflow::gtl::ArraySlice<int64> new_dimensions,
+      tensorflow::gtl::ArraySlice<int64> minor_to_major,
+      const LiteralSlice& literal);
+
+  // Creates a literal with the supplied shape, and uses the provided value
+  // generator to populate the literal's values.
+  // Returns the new literal object, or an error Status if failed.
+  template <
+      PrimitiveType type,
+      typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
+  static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
+      const Shape& shape,
+      const std::function<T(tensorflow::gtl::ArraySlice<int64>)>& generator);
+
+  // Creates a literal with the supplied shape, and initializes the literal
+  // values using a normal distribution with given mean and stddev standard
+  // deviation, and using the engine as entropy generator.
+  // Returns the new literal object, or an error Status if failed.
+  template <
+      PrimitiveType type, typename E,
+      typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
+  static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
+      const Shape& shape, E* engine, T mean, T stddev);
+
+  // Creates a literal with the supplied shape, and initializes the literal
+  // values using a normal distribution with given mean and stddev standard
+  // deviation.
+  // Returns the new literal object, or an error Status if failed.
+  template <
+      PrimitiveType type,
+      typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
+  static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
+      const Shape& shape, T mean, T stddev);
+
   //
   // End of factory methods.
 
+  // Returns a multi-dimensional index as a string. For example: '{7, 8}' will
+  // be returned for a 2-dimensional index with dimension 0 index equal to 7,
+  // dimension 1 equal to 8.
+  static string MultiIndexAsString(
+      tensorflow::gtl::ArraySlice<int64> multi_index);
+
  protected:
   // Recursively sets the subshapes and buffers of all subpieces rooted at
   // 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in
@@ -1558,6 +1615,38 @@ std::unique_ptr<Literal> LiteralBase::Replicate(int64 times) const {
   return literal;
 }
 
+template <PrimitiveType type, typename T>
+/* static */ StatusOr<std::unique_ptr<Literal>> Literal::CreateRandomLiteral(
+    const Shape& shape,
+    const std::function<T(tensorflow::gtl::ArraySlice<int64>)>& generator) {
+  using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
+  TF_RET_CHECK(shape.element_type() == type);
+  std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape);
+  TF_RETURN_IF_ERROR(literal.get()->Populate<NativeT>(
+      [&](tensorflow::gtl::ArraySlice<int64> indexes) {
+        return generator(indexes);
+      }));
+  return std::move(literal);
+}
+
+template <PrimitiveType type, typename E, typename T>
+/* static */ StatusOr<std::unique_ptr<Literal>> Literal::CreateRandomLiteral(
+    const Shape& shape, E* engine, T mean, T stddev) {
+  using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
+  std::normal_distribution<NativeT> generator(mean, stddev);
+  return CreateRandomLiteral<type, NativeT>(
+      shape, [&](tensorflow::gtl::ArraySlice<int64> /*indexes*/) {
+        return generator(*engine);
+      });
+}
+
+template <PrimitiveType type, typename T>
+/* static */ StatusOr<std::unique_ptr<Literal>> Literal::CreateRandomLiteral(
+    const Shape& shape, T mean, T stddev) {
+  std::minstd_rand0 engine;
+  return CreateRandomLiteral<type>(shape, &engine, mean, stddev);
+}
+
 }  // namespace xla
 
 #endif  // TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
index 10997c0..313f11a 100644 (file)
@@ -101,8 +101,8 @@ TEST_F(GRPCClientTestBase, AxpyTenValues) {
   TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build());
   TF_ASSERT_OK_AND_ASSIGN(auto result_literal, client_->ExecuteAndTransfer(
                                                    computation, {}, nullptr));
-  LiteralTestUtil::ExpectNear(*expected_literal, *result_literal,
-                              ErrorSpec(0.0001));
+  EXPECT_TRUE(LiteralTestUtil::Near(*expected_literal, *result_literal,
+                                    ErrorSpec(0.0001)));
 }
 
 }  // namespace
index 313910a..5e1499e 100644 (file)
@@ -149,12 +149,12 @@ TEST_F(BFloat16PropagationTest, ConvertConstantLiteral) {
   EXPECT_TRUE(OutputsBF16(dot->operand(1)));
   EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant);
   EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant);
-  LiteralTestUtil::ExpectEqual(
+  EXPECT_TRUE(LiteralTestUtil::Equal(
       dot->operand(0)->literal(),
-      *LiteralTestUtil::ConvertF32ToBF16(*Literal::CreateFromArray(array_a)));
-  LiteralTestUtil::ExpectEqual(
+      *Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_a))));
+  EXPECT_TRUE(LiteralTestUtil::Equal(
       dot->operand(1)->literal(),
-      *LiteralTestUtil::ConvertF32ToBF16(*Literal::CreateFromArray(array_b)));
+      *Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_b))));
 }
 
 // Tests that BF16 can be propagated through nested tuples.
index 7b552ee..5d05ccf 100644 (file)
@@ -149,7 +149,7 @@ TEST_F(HloConstantFoldingTest, Slice) {
   const int64 slice_limits[] = {10, 8, 6, 5, 9};
   const int64 slice_strides[] = {1, 1, 1, 1, 1};
   TF_ASSERT_OK_AND_ASSIGN(auto literal,
-                          LiteralTestUtil::CreateRandomLiteral<F32>(
+                          Literal::CreateRandomLiteral<F32>(
                               ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
   HloInstruction* literal_instruction = builder.AddInstruction(
       HloInstruction::CreateConstant(std::move(literal)));
@@ -172,7 +172,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
   HloComputation::Builder builder(TestName());
   const int64 dimensions[] = {11, 8, 7, 5, 9};
   TF_ASSERT_OK_AND_ASSIGN(auto literal,
-                          LiteralTestUtil::CreateRandomLiteral<F32>(
+                          Literal::CreateRandomLiteral<F32>(
                               ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
   auto literal_clone = literal->Literal::CloneToUnique();
   HloInstruction* literal_instruction = builder.AddInstruction(
index df8853f..a04b4f4 100644 (file)
@@ -72,7 +72,7 @@ TEST_F(HloCseTest, CombineTwoConstants) {
 
   auto result = ExecuteAndTransfer(std::move(module), {});
   auto expected = Literal::CreateR0<float>(84.0);
-  LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4));
+  EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
 }
 
 TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
@@ -104,7 +104,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
 
   auto result = ExecuteAndTransfer(std::move(module), {});
   auto expected = Literal::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
-  LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4));
+  EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
 }
 
 TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
@@ -134,7 +134,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
 
   auto result = ExecuteAndTransfer(std::move(module), {});
   auto expected = Literal::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
-  LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4));
+  EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
 }
 
 TEST_F(HloCseTest, ConstantsSameValueDifferentType) {
index 8e9688c..ae5b5e0 100644 (file)
@@ -82,9 +82,9 @@ class HloEvaluatorTest : public ::testing::WithParamInterface<bool>,
     auto element_type = expected->shape().element_type();
     if (element_type == F32 || element_type == F64) {
       ErrorSpec error(aabs);
-      LiteralTestUtil::ExpectNear(*expected, *result, error);
+      EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, error));
     } else {
-      LiteralTestUtil::ExpectEqual(*expected, *result);
+      EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
     }
   }
 
@@ -100,7 +100,7 @@ class HloEvaluatorTest : public ::testing::WithParamInterface<bool>,
 
     std::unique_ptr<Literal> result = Evaluate();
 
-    LiteralTestUtil::ExpectEqual(*expected, *result);
+    EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
   }
 
   bool use_bfloat16_;
@@ -129,7 +129,7 @@ TEST_P(HloEvaluatorTest, DoesClamp) {
 
   auto expected = Literal::CreateR2<float>({{0, 4}, {2, 4}});
 
-  LiteralTestUtil::ExpectEqual(*expected, *result);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
 }
 
 TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) {
@@ -150,7 +150,7 @@ TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) {
 
   auto expected = Literal::CreateR2<float>({{0, 0}, {1, 1}});
 
-  LiteralTestUtil::ExpectEqual(*expected, *result);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
 }
 
 // Verifies that HloEvaluator evaluates a HLO instruction that performs select
@@ -175,7 +175,7 @@ TEST_P(HloEvaluatorTest, DoesSelect) {
 
   auto expected = Literal::CreateR2<float>({{2, 5}, {0, 4}});
 
-  LiteralTestUtil::ExpectEqual(*expected, *result);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
 }
 
 // Verifies that HloEvaluator evaluates a HLO instruction that performs
@@ -307,7 +307,7 @@ TEST_P(HloEvaluatorTest, DoesTraverseInstructions) {
 
   auto expected = Literal::CreateR2<int64>({{4, -16}, {-196, 12}});
 
-  LiteralTestUtil::ExpectEqual(*expected, *result);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
 }
 
 // Verifies Reshape operation is correctly evaluated.
@@ -315,7 +315,7 @@ TEST_P(HloEvaluatorTest, DoesReshape) {
   HloComputation::Builder b(TestName());
   const int64 dimensions[] = {11, 8, 7, 5, 9};
   TF_ASSERT_OK_AND_ASSIGN(auto literal,
-                          LiteralTestUtil::CreateRandomLiteral<F32>(
+                          Literal::CreateRandomLiteral<F32>(
                               ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
   auto literal_clone = literal->CloneToUnique();
   HloInstruction* literal_instruction =
@@ -351,7 +351,7 @@ TEST_P(HloEvaluatorTest, DoesBroadcast) {
 
   std::unique_ptr<Literal> result = Evaluate({});
 
-  LiteralTestUtil::ExpectEqual(*result, *output_literal);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal));
 }
 
 TEST_P(HloEvaluatorTest, DoesBroadcastScalar) {
@@ -370,7 +370,7 @@ TEST_P(HloEvaluatorTest, DoesBroadcastScalar) {
 
   std::unique_ptr<Literal> result = Evaluate({});
 
-  LiteralTestUtil::ExpectEqual(*result, *output_literal);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal));
 }
 
 TEST_P(HloEvaluatorTest, DoesConcatenateSimple) {
@@ -392,7 +392,7 @@ TEST_P(HloEvaluatorTest, DoesConcatenateSimple) {
 
   auto expected =
       Literal::CreateR2<int64>({{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}});
-  LiteralTestUtil::ExpectEqual(*expected, *result);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
 }
 
 TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) {
@@ -413,7 +413,7 @@ TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) {
   std::unique_ptr<Literal> result = Evaluate();
 
   auto expected = Literal::CreateR1<int64>({100, 200});
-  LiteralTestUtil::ExpectEqual(*expected, *result);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
 }
 
 TEST_P(HloEvaluatorTest, ConvertWithSameLayout) {
@@ -432,7 +432,7 @@ TEST_P(HloEvaluatorTest, ConvertWithSameLayout) {
 
   std::unique_ptr<Literal> result = Evaluate();
 
-  LiteralTestUtil::ExpectEqual(*result, *expected);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
 }
 
 TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) {
@@ -452,7 +452,7 @@ TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) {
 
   std::unique_ptr<Literal> result = Evaluate();
 
-  LiteralTestUtil::ExpectEqual(*result, *expected);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
 }
 
 PaddingConfig CreatePaddingConfig(
@@ -490,7 +490,7 @@ TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) {
   auto expected = Literal::CreateR2<int32>(
       {{10, 10}, {10, 10}, {10, 10}, {10, 10}, {10, 10}});
 
-  LiteralTestUtil::ExpectEqual(*expected, *result);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
 }
 
 TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) {
@@ -525,7 +525,7 @@ TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) {
 
   auto expected = Literal::CreateR4FromArray4D<float>(*expected_array);
 
-  LiteralTestUtil::ExpectEqual(*expected, *result);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
 }
 
 TEST_P(HloEvaluatorTest, NegativePadding2D) {
@@ -567,7 +567,7 @@ TEST_P(HloEvaluatorTest, NegativePadding2D) {
   (*expected_array)(0, 4) = 2.718f;
   auto expected = Literal::CreateR2FromArray2D<float>(*expected_array);
 
-  LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(0x1.0P-5));
+  EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(0x1.0P-5)));
 }
 
 TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) {
@@ -606,7 +606,7 @@ TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) {
   auto expected_array = MakeUnique<Array2D<float>>(0, 9);
   auto expected = Literal::CreateR2FromArray2D<float>(*expected_array);
 
-  LiteralTestUtil::ExpectEqual(*expected, *result);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
 }
 
 TEST_P(HloEvaluatorTest, DotRank2AndRank1) {
@@ -651,7 +651,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank1) {
   // clang-format on
   auto expected = Literal::CreateR2FromArray2D<float>(expected_array);
 
-  LiteralTestUtil::ExpectEqual(*expected, *result);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
 }
 
 TEST_P(HloEvaluatorTest, DotRank1AndRank2) {
@@ -688,7 +688,7 @@ TEST_P(HloEvaluatorTest, DotRank1AndRank2) {
 
   auto expected = Literal::CreateR1<float>({22.f, 28.f});
 
-  LiteralTestUtil::ExpectEqual(*expected, *result);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
 }
 
 TEST_P(HloEvaluatorTest, DotRank2AndRank2) {
@@ -737,7 +737,7 @@ TEST_P(HloEvaluatorTest, DotRank2AndRank2) {
   });
   auto expected = Literal::CreateR2FromArray2D<float>(expected_array);
 
-  LiteralTestUtil::ExpectEqual(*expected, *result);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
 }
 
 TEST_P(HloEvaluatorTest, SimpleConv1D) {
@@ -785,7 +785,7 @@ TEST_P(HloEvaluatorTest, SimpleConv1D) {
   Array3D<float> expected_array = {{{11.f, 18.f, 9.f}}};
   auto expected = Literal::CreateR3FromArray3D<float>(expected_array);
 
-  LiteralTestUtil::ExpectEqual(*expected, *result);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
 }
 
 TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) {
@@ -847,7 +847,7 @@ TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) {
   // clang-format on
   auto expected = Literal::CreateR4FromArray4D<float>(expected_array);
 
-  LiteralTestUtil::ExpectEqual(*expected, *result);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
 }
 
 TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) {
@@ -927,7 +927,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) {
   auto expected = Literal::CreateR4FromArray4D<float>(
       use_bfloat16_ ? expected_array_bf16 : expected_array);
 
-  LiteralTestUtil::ExpectEqual(*expected, *result);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
 }
 
 TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) {
@@ -1004,7 +1004,7 @@ TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) {
   auto expected = Literal::CreateR4FromArray4D<float>(
       use_bfloat16_ ? expected_array_bf16 : expected_array);
 
-  LiteralTestUtil::ExpectEqual(*expected, *result);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
 }
 
 TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) {
@@ -1067,7 +1067,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) {
   }));
   auto expected = Literal::CreateR4FromArray4D<float>(expected_array);
 
-  LiteralTestUtil::ExpectEqual(*expected, *result);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
 }
 
 TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) {
@@ -1131,7 +1131,7 @@ TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) {
   }));
   auto expected = Literal::CreateR4FromArray4D<float>(expected_array);
 
-  LiteralTestUtil::ExpectEqual(*expected, *result);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
 }
 
 TEST_P(HloEvaluatorTest,
@@ -1203,7 +1203,7 @@ TEST_P(HloEvaluatorTest,
   }));
   auto expected = Literal::CreateR4FromArray4D<float>(expected_array);
 
-  LiteralTestUtil::ExpectEqual(*expected, *result);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
 }
 
 class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {};
@@ -1319,7 +1319,7 @@ TEST_P(HloEvaluatorTest, ReduceAdd) {
 
   auto expected = Literal::CreateR1<float>({6, 18});
 
-  LiteralTestUtil::ExpectEqual(*expected, *result);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
 }
 
 TEST_P(HloEvaluatorTest, ReduceWindowMax) {
@@ -1370,7 +1370,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) {
   std::unique_ptr<Literal> result = Evaluate();
 
   auto expected = Literal::CreateR2<float>({{6, 7}});
-  LiteralTestUtil::ExpectEqual(*expected, *result);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
 }
 
 TEST_P(HloEvaluatorTest, ReduceWindowAdd) {
@@ -1427,7 +1427,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd) {
   std::unique_ptr<Literal> result = Evaluate();
 
   auto expected = Literal::CreateR2<float>({{1, 3, 5}, {5, 11, 13}});
-  LiteralTestUtil::ExpectEqual(*expected, *result);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
 }
 
 TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) {
@@ -1490,7 +1490,7 @@ TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) {
   std::vector<int64> output_dims = {4, 3, 3, 3, 4, 4};
   std::unique_ptr<Literal> result_literal =
       Literal::CreateFullWithDescendingLayout<float>(output_dims, 8.0f);
-  LiteralTestUtil::ExpectEqual(*result_literal, *result);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*result_literal, *result));
 }
 
 TEST_P(HloEvaluatorTest, StridedSlice) {
@@ -1523,7 +1523,7 @@ TEST_P(HloEvaluatorTest, StridedSlice) {
       {19},
   });
 
-  LiteralTestUtil::ExpectEqual(*expected, *result);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
 }
 
 TEST_P(HloEvaluatorTest, DynamicSlice) {
@@ -1556,7 +1556,7 @@ TEST_P(HloEvaluatorTest, DynamicSlice) {
       {6, 7, 8},
   });
 
-  LiteralTestUtil::ExpectEqual(*expected, *result);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
 }
 
 // Verifies that the HloEvaluator's implementation goes along with existing
@@ -1591,7 +1591,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceModSlice) {
       {6, 7, 8},
   });
 
-  LiteralTestUtil::ExpectEqual(*expected, *result);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
 }
 
 TEST_P(HloEvaluatorTest, DynamicSliceUpdate) {
@@ -1627,7 +1627,7 @@ TEST_P(HloEvaluatorTest, DynamicSliceUpdate) {
       {5, -6, -7},
   });
 
-  LiteralTestUtil::ExpectEqual(*expected, *result);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
 }
 
 TEST_P(HloEvaluatorTest, SetAndGetTuples) {
@@ -1662,7 +1662,7 @@ TEST_P(HloEvaluatorTest, SetAndGetTuples) {
       {5, 6, 7},
   });
 
-  LiteralTestUtil::ExpectEqual(*expected, *result);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
 }
 
 TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) {
@@ -1703,7 +1703,7 @@ TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) {
       result_inner_literal.get(),
   });
 
-  LiteralTestUtil::ExpectEqual(*expected, *result);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
 }
 
 TEST_P(HloEvaluatorTest, Reverse) {
@@ -1756,7 +1756,7 @@ TEST_P(HloEvaluatorTest, Reverse) {
   });
   // clang-format on
 
-  LiteralTestUtil::ExpectEqual(*expected, *result);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
 }
 
 TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) {
@@ -1776,8 +1776,8 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) {
       add, {{param0, Literal::CreateR1<float>({1, 2, 3, 4}).get()},
             {square, Literal::CreateR1<float>({10, 20, 30, 40}).get()}});
   TF_ASSERT_OK(result.status());
-  LiteralTestUtil::ExpectEqual(*Literal::CreateR1<float>({11, 22, 33, 44}),
-                               *result.ValueOrDie());
+  EXPECT_TRUE(LiteralTestUtil::Equal(
+      *Literal::CreateR1<float>({11, 22, 33, 44}), *result.ValueOrDie()));
 }
 
 // Check that EvaluateWithSubstitutions works if one of the operands to the op
@@ -1800,8 +1800,8 @@ TEST_P(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) {
   auto result = evaluator.EvaluateWithSubstitutions(
       add, {{square, Literal::CreateR1<float>({10, 20, 30, 40}).get()}});
   TF_ASSERT_OK(result.status());
-  LiteralTestUtil::ExpectEqual(*Literal::CreateR1<float>({11, 22, 33, 44}),
-                               *result.ValueOrDie());
+  EXPECT_TRUE(LiteralTestUtil::Equal(
+      *Literal::CreateR1<float>({11, 22, 33, 44}), *result.ValueOrDie()));
 }
 
 TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) {
@@ -1823,9 +1823,9 @@ ENTRY main {
   std::unique_ptr<Literal> operand =
       Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
   std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2});
-  LiteralTestUtil::ExpectEqual(
-      *Literal::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}}),
-      *Evaluate({operand.get(), gather_indices.get()}));
+  EXPECT_TRUE(
+      LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}}),
+                             *Evaluate({operand.get(), gather_indices.get()})));
 }
 
 TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) {
@@ -1847,9 +1847,9 @@ ENTRY main {
   std::unique_ptr<Literal> operand =
       Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
   std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2});
-  LiteralTestUtil::ExpectEqual(
+  EXPECT_TRUE(LiteralTestUtil::Equal(
       *Literal::CreateR2<int32>({{1, 3}, {4, 6}, {7, 9}}),
-      *Evaluate({operand.get(), gather_indices.get()}));
+      *Evaluate({operand.get(), gather_indices.get()})));
 }
 
 TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) {
@@ -1872,10 +1872,10 @@ ENTRY main {
       Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
   std::unique_ptr<Literal> gather_indices =
       Literal::CreateR2<int32>({{0, 2}, {2, 1}});
-  LiteralTestUtil::ExpectEqual(
+  EXPECT_TRUE(LiteralTestUtil::Equal(
       *Literal::CreateR3<int32>(
           {{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}}),
-      *Evaluate({operand.get(), gather_indices.get()}));
+      *Evaluate({operand.get(), gather_indices.get()})));
 }
 
 TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) {
@@ -1900,9 +1900,9 @@ ENTRY main {
                                 {{-7, 7}, {-8, 8}, {-9, 9}}});
   std::unique_ptr<Literal> gather_indices =
       Literal::CreateR2<int32>({{0, 0}, {1, 0}});
-  LiteralTestUtil::ExpectEqual(
-      *Literal::CreateR2<int32>({{-1, 1}, {-4, 4}}),
-      *Evaluate({operand.get(), gather_indices.get()}));
+  EXPECT_TRUE(
+      LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{-1, 1}, {-4, 4}}),
+                             *Evaluate({operand.get(), gather_indices.get()})));
 }
 
 TEST_P(HloEvaluatorTest,
@@ -1928,9 +1928,9 @@ ENTRY main {
                                 {{-7, 7}, {-8, 8}, {-9, 9}}});
   std::unique_ptr<Literal> gather_indices =
       Literal::CreateR2<int32>({{0, 0}, {1, 0}});
-  LiteralTestUtil::ExpectEqual(
-      *Literal::CreateR2<int32>({{-2, 2}, {-1, 1}}),
-      *Evaluate({operand.get(), gather_indices.get()}));
+  EXPECT_TRUE(
+      LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{-2, 2}, {-1, 1}}),
+                             *Evaluate({operand.get(), gather_indices.get()})));
 }
 
 TEST_P(HloEvaluatorTest, EvaluateGather_DynamicSlice) {
@@ -1952,9 +1952,9 @@ ENTRY main {
   std::unique_ptr<Literal> operand =
       Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
   std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({1, 1});
-  LiteralTestUtil::ExpectEqual(
-      *Literal::CreateR2<int32>({{5}}),
-      *Evaluate({operand.get(), gather_indices.get()}));
+  EXPECT_TRUE(
+      LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{5}}),
+                             *Evaluate({operand.get(), gather_indices.get()})));
 }
 
 TEST_P(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) {
@@ -1977,9 +1977,9 @@ ENTRY main {
       Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
   std::unique_ptr<Literal> gather_indices =
       Literal::CreateR2<int32>({{2, 1}, {1, 1}});
-  LiteralTestUtil::ExpectEqual(
-      *Literal::CreateR3<int32>({{{8}}, {{5}}}),
-      *Evaluate({operand.get(), gather_indices.get()}));
+  EXPECT_TRUE(
+      LiteralTestUtil::Equal(*Literal::CreateR3<int32>({{{8}}, {{5}}}),
+                             *Evaluate({operand.get(), gather_indices.get()})));
 }
 
 TEST_P(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) {
@@ -2000,9 +2000,9 @@ ENTRY main {
   ParseAndVerifyModule(hlo_text);
   std::unique_ptr<Literal> operand = Literal::CreateR2<int32>({{}, {}, {}});
   std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2});
-  LiteralTestUtil::ExpectEqual(
-      *Literal::CreateR2<int32>({{}, {}}),
-      *Evaluate({operand.get(), gather_indices.get()}));
+  EXPECT_TRUE(
+      LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{}, {}}),
+                             *Evaluate({operand.get(), gather_indices.get()})));
 }
 
 TEST_P(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) {
@@ -2025,9 +2025,9 @@ ENTRY main {
   std::unique_ptr<Literal> operand = Literal::CreateR1<int32>({0, 1, 2});
   std::unique_ptr<Literal> gather_indices =
       Literal::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
-  LiteralTestUtil::ExpectEqual(
-      *Literal::CreateR2<int32>({{0, 1}, {2, 1}}),
-      *Evaluate({operand.get(), gather_indices.get()}));
+  EXPECT_TRUE(
+      LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{0, 1}, {2, 1}}),
+                             *Evaluate({operand.get(), gather_indices.get()})));
 }
 
 // Verifies that HloEvaluator evaluates a HLO instruction that performs
index 7aa1c7c..d2af261 100644 (file)
@@ -71,7 +71,7 @@ TEST_F(InlinerTest, MapMax) {
   // Verify execution on CPU.
   auto result = ExecuteAndTransfer(std::move(hlo_module), {});
   auto expected = Literal::CreateR1<float>({4, 3, 3, 4});
-  LiteralTestUtil::ExpectEqual(*result, *expected);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
 }
 
 // Test that `constant` function is changed to `broadcast`.
@@ -105,7 +105,7 @@ TEST_F(InlinerTest, MapConstant) {
   // Verify execution on CPU.
   auto result = ExecuteAndTransfer(std::move(hlo_module), {});
   auto expected = Literal::CreateR2<float>({{2, 2, 2, 2}, {2, 2, 2, 2}});
-  LiteralTestUtil::ExpectEqual(*result, *expected);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
 }
 
 TEST_F(InlinerTest, MapSubtractOppositeOrder) {
@@ -143,7 +143,7 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) {
   // Verify execution on CPU.
   auto result = ExecuteAndTransfer(std::move(hlo_module), {});
   auto expected = Literal::CreateR1<float>({3, 1, -1, -3});
-  LiteralTestUtil::ExpectEqual(*result, *expected);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
 }
 
 
index b982cf0..4b0dfde 100644 (file)
@@ -87,6 +87,7 @@ cc_library(
         "//tensorflow/compiler/xla:array2d",
         "//tensorflow/compiler/xla:array3d",
         "//tensorflow/compiler/xla:array4d",
+        "//tensorflow/compiler/xla:literal_comparison",
         "//tensorflow/compiler/xla:literal_util",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:test",
index a180cdd..51b9f0d 100644 (file)
@@ -46,8 +46,8 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) {
   hlo_module->AddEntryComputation(builder.Build());
   auto result = ExecuteAndTransfer(std::move(hlo_module), {});
 
-  LiteralTestUtil::ExpectNear(*Literal::CreateR0<float>(42.0), *result,
-                              error_spec_);
+  EXPECT_TRUE(LiteralTestUtil::Near(*Literal::CreateR0<float>(42.0), *result,
+                                    error_spec_));
 }
 
 XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) {
@@ -62,9 +62,9 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) {
   hlo_module->AddEntryComputation(builder.Build());
   auto result = ExecuteAndTransfer(std::move(hlo_module), {});
 
-  LiteralTestUtil::ExpectNear(
+  EXPECT_TRUE(LiteralTestUtil::Near(
       *Literal::CreateR2<float>({{42.0, 42.0}, {42.0, 42.0}}), *result,
-      error_spec_);
+      error_spec_));
 }
 
 XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) {
@@ -85,13 +85,13 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) {
   hlo_module->AddEntryComputation(builder.Build());
   auto result = ExecuteAndTransfer(std::move(hlo_module), {});
 
-  LiteralTestUtil::ExpectNear(
+  EXPECT_TRUE(LiteralTestUtil::Near(
       *Literal::CreateR2<float>({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}),
-      LiteralSlice(*result, {0}), error_spec_);
+      LiteralSlice(*result, {0}), error_spec_));
 
-  LiteralTestUtil::ExpectNear(
+  EXPECT_TRUE(LiteralTestUtil::Near(
       *Literal::CreateR2<float>({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}),
-      LiteralSlice(*result, {1}), error_spec_);
+      LiteralSlice(*result, {1}), error_spec_));
 }
 
 XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
@@ -106,9 +106,9 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
   hlo_module->AddEntryComputation(builder.Build());
   auto result = ExecuteAndTransfer(std::move(hlo_module), {});
 
-  LiteralTestUtil::ExpectNear(
-      *Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}), *result,
-      error_spec_);
+  EXPECT_TRUE(
+      LiteralTestUtil::Near(*Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
+                            *result, error_spec_));
 }
 
 XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) {
@@ -125,9 +125,9 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) {
   hlo_module->AddEntryComputation(builder.Build());
   auto result = ExecuteAndTransfer(std::move(hlo_module), {});
 
-  LiteralTestUtil::ExpectNear(
-      *Literal::CreateR2<float>({{1.0, 3.0}, {2.0, 4.0}}), *result,
-      error_spec_);
+  EXPECT_TRUE(
+      LiteralTestUtil::Near(*Literal::CreateR2<float>({{1.0, 3.0}, {2.0, 4.0}}),
+                            *result, error_spec_));
 }
 
 XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) {
@@ -142,10 +142,10 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) {
   hlo_module->AddEntryComputation(builder.Build());
   auto result = ExecuteAndTransfer(std::move(hlo_module), {});
 
-  LiteralTestUtil::ExpectNear(
+  EXPECT_TRUE(LiteralTestUtil::Near(
       *Literal::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
                                  {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}),
-      *result, error_spec_);
+      *result, error_spec_));
 }
 
 TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) {
@@ -166,8 +166,8 @@ TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) {
   Array2D<float> pz({{1, 2}, {1, 2}});
   expected.FillWithPZ(pz);
 
-  LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D<float>(expected),
-                              *result, error_spec_);
+  EXPECT_TRUE(LiteralTestUtil::Near(
+      *Literal::CreateR4FromArray4D<float>(expected), *result, error_spec_));
 }
 
 TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) {
@@ -196,8 +196,8 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) {
   }
   expected.FillWithYX(yx);
 
-  LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D<float>(expected),
-                              *result, error_spec_);
+  EXPECT_TRUE(LiteralTestUtil::Near(
+      *Literal::CreateR4FromArray4D<float>(expected), *result, error_spec_));
 }
 
 XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) {
@@ -218,8 +218,8 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) {
   hlo_module->AddEntryComputation(builder.Build());
   auto result = ExecuteAndTransfer(std::move(hlo_module), {});
 
-  LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(r4_array), *result,
-                              error_spec_);
+  EXPECT_TRUE(LiteralTestUtil::Near(*Literal::CreateR4FromArray4D(r4_array),
+                                    *result, error_spec_));
 }
 
 TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) {
@@ -238,8 +238,8 @@ TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) {
   Array4D<float> expected(64, 64, 3, 3);
   expected.Fill(1.0f);
 
-  LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D<float>(expected),
-                              *result, error_spec_);
+  EXPECT_TRUE(LiteralTestUtil::Near(
+      *Literal::CreateR4FromArray4D<float>(expected), *result, error_spec_));
 }
 
 TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) {
@@ -260,8 +260,8 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) {
   Array4D<float> expected(3, 3, 2, 2);
   expected.FillWithYX(to_broadcast);
 
-  LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D<float>(expected),
-                              *result, error_spec_);
+  EXPECT_TRUE(LiteralTestUtil::Near(
+      *Literal::CreateR4FromArray4D<float>(expected), *result, error_spec_));
 }
 
 TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
@@ -291,8 +291,8 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
   hlo_module->AddEntryComputation(builder.Build());
   auto result = ExecuteAndTransfer(std::move(hlo_module), {});
 
-  LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D<float>(expected),
-                              *result, error_spec_);
+  EXPECT_TRUE(LiteralTestUtil::Near(
+      *Literal::CreateR4FromArray4D<float>(expected), *result, error_spec_));
 }
 
 }  // namespace
index 41f9a5f..be542c1 100644 (file)
@@ -297,7 +297,7 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
   std::unique_ptr<Literal> converted_expected;
   Shape layout_shape;
   if (use_bfloat16_) {
-    converted_expected = LiteralTestUtil::ConvertF32ToBF16(expected);
+    converted_expected = Literal::ConvertF32ToBF16(expected);
     expected_ptr = converted_expected.get();
     if (shape_with_layout != nullptr) {
       layout_shape = *shape_with_layout;
@@ -311,7 +311,7 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
     }
   }
   auto expect_equal = [&](const Literal& actual, const string& error_message) {
-    LiteralTestUtil::ExpectEqual(*expected_ptr, actual, error_message);
+    EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, actual)) << error_message;
   };
   if (execution_options_.debug_options().xla_test_all_output_layouts()) {
     return ComputeAndCompareLiteralWithAllOutputLayouts(
@@ -323,7 +323,7 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
   }
   TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments,
                                                       shape_with_layout));
-  LiteralTestUtil::ExpectEqual(*expected_ptr, *actual);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, *actual));
   return tensorflow::Status::OK();
 }
 
@@ -349,7 +349,7 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
   std::unique_ptr<Literal> converted_expected;
   Shape layout_shape;
   if (use_bfloat16_) {
-    converted_expected = LiteralTestUtil::ConvertF32ToBF16(expected);
+    converted_expected = Literal::ConvertF32ToBF16(expected);
     expected_ptr = converted_expected.get();
     if (shape_with_layout != nullptr) {
       layout_shape = *shape_with_layout;
@@ -363,7 +363,8 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
     }
   }
   auto expect_near = [&](const Literal& actual, const string& error_message) {
-    LiteralTestUtil::ExpectNear(*expected_ptr, actual, error, error_message);
+    EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, actual, error))
+        << error_message;
   };
   if (execution_options_.debug_options().xla_test_all_output_layouts()) {
     return ComputeAndCompareLiteralWithAllOutputLayouts(
@@ -375,7 +376,7 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
   }
   TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments,
                                                       shape_with_layout));
-  LiteralTestUtil::ExpectNear(*expected_ptr, *actual, error);
+  EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, *actual, error));
   return tensorflow::Status::OK();
 }
 
@@ -407,7 +408,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple(
     return;
   }
   auto actual = actual_status.ConsumeValueOrDie();
-  LiteralTestUtil::ExpectEqual(expected, *actual);
+  EXPECT_TRUE(LiteralTestUtil::Equal(expected, *actual));
 }
 
 void ClientLibraryTestBase::ComputeAndCompareTuple(
@@ -419,7 +420,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple(
     return;
   }
   auto actual = actual_status.ConsumeValueOrDie();
-  LiteralTestUtil::ExpectNear(expected, *actual, error);
+  EXPECT_TRUE(LiteralTestUtil::Near(expected, *actual, error));
 }
 
 void ClientLibraryTestBase::ComputeAndCompare(
@@ -431,7 +432,7 @@ void ClientLibraryTestBase::ComputeAndCompare(
   }
   std::unique_ptr<Literal> reference, result;
   std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
-  LiteralTestUtil::ExpectEqual(*reference, *result);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*reference, *result));
 }
 
 void ClientLibraryTestBase::ComputeAndCompare(
@@ -444,7 +445,7 @@ void ClientLibraryTestBase::ComputeAndCompare(
   }
   std::unique_ptr<Literal> reference, result;
   std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
-  LiteralTestUtil::ExpectNear(*reference, *result, error);
+  EXPECT_TRUE(LiteralTestUtil::Near(*reference, *result, error));
 }
 
 StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
@@ -562,7 +563,7 @@ XlaOp ClientLibraryTestBase::AddParam(const Literal& argument,
 XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal,
                                                        XlaBuilder* builder) {
   return builder->ConstantLiteral(
-      use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal);
+      use_bfloat16_ ? *Literal::ConvertF32ToBF16(literal) : literal);
 }
 
 std::unique_ptr<GlobalData>
@@ -583,7 +584,7 @@ ClientLibraryTestBase::CreateParameterAndTransferLiteral(
   const Literal* param_literal = &literal;
   std::unique_ptr<Literal> converted_literal;
   if (use_bfloat16_) {
-    converted_literal = LiteralTestUtil::ConvertF32ToBF16(literal);
+    converted_literal = Literal::ConvertF32ToBF16(literal);
     param_literal = converted_literal.get();
   }
   std::unique_ptr<GlobalData> data =
index 16e838e..c8c3af0 100644 (file)
@@ -541,7 +541,7 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR0Parameter(
     XlaBuilder* builder, XlaOp* data_handle) {
   std::unique_ptr<Literal> literal = Literal::CreateR0(value);
   if (use_bfloat16_ && literal->shape().element_type() == F32) {
-    literal = LiteralTestUtil::ConvertF32ToBF16(*literal);
+    literal = Literal::ConvertF32ToBF16(*literal);
   }
   std::unique_ptr<GlobalData> data =
       client_->TransferToServer(*literal).ConsumeValueOrDie();
@@ -555,7 +555,7 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR1Parameter(
     const string& name, XlaBuilder* builder, XlaOp* data_handle) {
   std::unique_ptr<Literal> literal = Literal::CreateR1(values);
   if (use_bfloat16_ && literal->shape().element_type() == F32) {
-    literal = LiteralTestUtil::ConvertF32ToBF16(*literal);
+    literal = Literal::ConvertF32ToBF16(*literal);
   }
   std::unique_ptr<GlobalData> data =
       client_->TransferToServer(*literal).ConsumeValueOrDie();
@@ -569,7 +569,7 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR2Parameter(
     const string& name, XlaBuilder* builder, XlaOp* data_handle) {
   std::unique_ptr<Literal> literal = Literal::CreateR2FromArray2D(array_2d);
   if (use_bfloat16_ && literal->shape().element_type() == F32) {
-    literal = LiteralTestUtil::ConvertF32ToBF16(*literal);
+    literal = Literal::ConvertF32ToBF16(*literal);
   }
   std::unique_ptr<GlobalData> data =
       client_->TransferToServer(*literal).ConsumeValueOrDie();
@@ -583,7 +583,7 @@ std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR3Parameter(
     const string& name, XlaBuilder* builder, XlaOp* data_handle) {
   std::unique_ptr<Literal> literal = Literal::CreateR3FromArray3D(array_3d);
   if (use_bfloat16_ && literal->shape().element_type() == F32) {
-    literal = LiteralTestUtil::ConvertF32ToBF16(*literal);
+    literal = Literal::ConvertF32ToBF16(*literal);
   }
   std::unique_ptr<GlobalData> data =
       client_->TransferToServer(*literal).ConsumeValueOrDie();
index abf7312..08671cf 100644 (file)
@@ -62,9 +62,9 @@ XLA_TEST_F(ClientTest, ExecuteWithLayout) {
       TF_ASSERT_OK_AND_ASSIGN(
           auto computed, client_->Transfer(*data, &expected_literal->shape()));
 
-      LiteralTestUtil::AssertEqualShapesAndLayouts(expected_literal->shape(),
-                                                   computed->shape());
-      LiteralTestUtil::ExpectEqual(*expected_literal, *computed);
+      ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts(
+          expected_literal->shape(), computed->shape()));
+      EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
     }
   }
 }
@@ -142,7 +142,7 @@ XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) {
       auto result_literal,
       client_->Transfer(*results[0], &expected_result->shape()));
 
-  LiteralTestUtil::ExpectEqual(*expected_result, *result_literal);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*expected_result, *result_literal));
 }
 
 }  // namespace
index ecce599..e1aa9d7 100644 (file)
@@ -50,8 +50,8 @@ class CompilationCacheTest : public ClientLibraryTestBase {
                                  /*execution_options=*/&execution_options_,
                                  &execution_profile)
             .ConsumeValueOrDie();
-    LiteralTestUtil::ExpectNear(*Literal::CreateR0<float>(expected_result),
-                                *result, error_spec_);
+    EXPECT_TRUE(LiteralTestUtil::Near(
+        *Literal::CreateR0<float>(expected_result), *result, error_spec_));
     EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
   }
 
@@ -67,8 +67,8 @@ class CompilationCacheTest : public ClientLibraryTestBase {
                            .ConsumeValueOrDie();
     std::unique_ptr<Literal> result =
         client_->Transfer(*data_handle).ConsumeValueOrDie();
-    LiteralTestUtil::ExpectNear(*Literal::CreateR2<float>(expected_result),
-                                *result, error_spec_);
+    EXPECT_TRUE(LiteralTestUtil::Near(
+        *Literal::CreateR2<float>(expected_result), *result, error_spec_));
     EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
   }
 
index bf4b8fb..ba22530 100644 (file)
@@ -208,7 +208,7 @@ TEST_F(ComputeConstantTest, NonScalarAdd) {
                             ComputeConstantLiteral(client, computation, &b));
     std::unique_ptr<Literal> expected_literal =
         Literal::CreateR1<int32>({4, 6});
-    LiteralTestUtil::ExpectEqual(*expected_literal, *computed);
+    EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
   }
 }
 
@@ -222,7 +222,7 @@ TEST_F(ComputeConstantTest, IntegerDivide) {
     TF_ASSERT_OK_AND_ASSIGN(auto computed,
                             ComputeConstantLiteral(client, computation, &b));
     std::unique_ptr<Literal> expected_literal = Literal::CreateR0<int32>(5);
-    LiteralTestUtil::ExpectEqual(*expected_literal, *computed);
+    EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
   }
 }
 
@@ -244,9 +244,9 @@ XLA_TEST_F(ComputeConstantTest, Layout) {
       std::unique_ptr<Literal> expected_literal =
           Literal::CreateR2WithLayout<int32>({{11, 22}, {33, 44}},
                                              LayoutUtil::MakeLayout(layout));
-      LiteralTestUtil::AssertEqualShapesAndLayouts(expected_literal->shape(),
-                                                   computed->shape());
-      LiteralTestUtil::ExpectEqual(*expected_literal, *computed);
+      ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts(
+          expected_literal->shape(), computed->shape()));
+      EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
     }
   }
 }
index 155fbac..2b3390c 100644 (file)
@@ -49,7 +49,7 @@ class CopyOpTest : public HloTestBase {
     module->AddEntryComputation(std::move(computation));
 
     std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
-    LiteralTestUtil::ExpectEqual(literal, *result);
+    EXPECT_TRUE(LiteralTestUtil::Equal(literal, *result));
   }
 
   void TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3);
@@ -253,7 +253,7 @@ XLA_TEST_F(CopyOpClientTest, Copy0x0) {
 
   auto actual = ExecuteAndTransfer(&builder, {input_data.get()}, &out_shape)
                     .ConsumeValueOrDie();
-  LiteralTestUtil::ExpectEqual(*empty, *actual);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*empty, *actual));
 }
 
 }  // namespace
index b947f82..e6f79b5 100644 (file)
@@ -118,9 +118,9 @@ class FusionTest : public HloTestBase {
     auto expected = Literal::CreateR2FromArray2D(answer_data);
     auto actual = ExecuteAndTransfer(std::move(hlo_module), {});
     if (primitive_util::IsFloatingPointType(prim_type)) {
-      LiteralTestUtil::ExpectNear(*expected, *actual, ErrorSpec(1e-4));
+      EXPECT_TRUE(LiteralTestUtil::Near(*expected, *actual, ErrorSpec(1e-4)));
     } else {
-      LiteralTestUtil::ExpectEqual(*expected, *actual);
+      EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual));
     }
   }
 
@@ -221,9 +221,9 @@ XLA_TEST_F(FusionTest, Test) {
            const4, reshape3, add2, const1, const0},
           HloInstruction::FusionKind::kLoop);
 
-  LiteralTestUtil::ExpectNear(*Literal::CreateR2<float>({{0.5}, {2.72}}),
-                              *ExecuteAndTransfer(std::move(hlo_module), {}),
-                              ErrorSpec(1e-4));
+  EXPECT_TRUE(LiteralTestUtil::Near(
+      *Literal::CreateR2<float>({{0.5}, {2.72}}),
+      *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
 }
 
 // Test whether we emit appropriate code for parameters of fusion instructions.
@@ -247,9 +247,9 @@ XLA_TEST_F(FusionTest, Parameter) {
       ->CreateFusionInstruction(/*instructions_to_fuse=*/{add3, const2},
                                 HloInstruction::FusionKind::kLoop);
 
-  LiteralTestUtil::ExpectNear(*Literal::CreateR2<float>({{-1.0, 0.0, 1.0}}),
-                              *ExecuteAndTransfer(std::move(hlo_module), {}),
-                              ErrorSpec(1e-4));
+  EXPECT_TRUE(LiteralTestUtil::Near(
+      *Literal::CreateR2<float>({{-1.0, 0.0, 1.0}}),
+      *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
 }
 
 XLA_TEST_F(FusionTest, RandomizedParallelPartition) {
@@ -307,9 +307,9 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) {
       ->CreateFusionInstruction(/*instructions_to_fuse=*/{add2, broadcast},
                                 HloInstruction::FusionKind::kLoop);
 
-  LiteralTestUtil::ExpectNear(
+  EXPECT_TRUE(LiteralTestUtil::Near(
       *Literal::CreateR2<float>({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}),
-      *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4));
+      *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
 }
 
 XLA_TEST_F(FusionTest, ReshapeToScalar) {
@@ -322,8 +322,9 @@ XLA_TEST_F(FusionTest, ReshapeToScalar) {
   hlo_module->AddEntryComputation(builder.Build())
       ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape},
                                 HloInstruction::FusionKind::kLoop);
-  LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(5),
-                               *ExecuteAndTransfer(std::move(hlo_module), {}));
+  EXPECT_TRUE(
+      LiteralTestUtil::Equal(*Literal::CreateR0<int32>(5),
+                             *ExecuteAndTransfer(std::move(hlo_module), {})));
 }
 
 XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) {
@@ -336,9 +337,9 @@ XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) {
   hlo_module->AddEntryComputation(builder.Build())
       ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
                                 HloInstruction::FusionKind::kLoop);
-  LiteralTestUtil::ExpectEqual(
+  EXPECT_TRUE(LiteralTestUtil::Equal(
       *Literal::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}}),
-      *ExecuteAndTransfer(std::move(hlo_module), {}));
+      *ExecuteAndTransfer(std::move(hlo_module), {})));
 }
 
 XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) {
@@ -351,9 +352,9 @@ XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) {
   hlo_module->AddEntryComputation(builder.Build())
       ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
                                 HloInstruction::FusionKind::kLoop);
-  LiteralTestUtil::ExpectEqual(
+  EXPECT_TRUE(LiteralTestUtil::Equal(
       *Literal::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}}),
-      *ExecuteAndTransfer(std::move(hlo_module), {}));
+      *ExecuteAndTransfer(std::move(hlo_module), {})));
 }
 
 XLA_TEST_F(FusionTest, Reshape_1by1by1_) {
@@ -366,8 +367,9 @@ XLA_TEST_F(FusionTest, Reshape_1by1by1_) {
   hlo_module->AddEntryComputation(builder.Build())
       ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
                                 HloInstruction::FusionKind::kLoop);
-  LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(7),
-                               *ExecuteAndTransfer(std::move(hlo_module), {}));
+  EXPECT_TRUE(
+      LiteralTestUtil::Equal(*Literal::CreateR0<int32>(7),
+                             *ExecuteAndTransfer(std::move(hlo_module), {})));
 }
 
 XLA_TEST_F(FusionTest, Reshape__1by1by1) {
@@ -380,8 +382,9 @@ XLA_TEST_F(FusionTest, Reshape__1by1by1) {
   hlo_module->AddEntryComputation(builder.Build())
       ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
                                 HloInstruction::FusionKind::kLoop);
-  LiteralTestUtil::ExpectEqual(*Literal::CreateR3<int32>({{{7}}}),
-                               *ExecuteAndTransfer(std::move(hlo_module), {}));
+  EXPECT_TRUE(
+      LiteralTestUtil::Equal(*Literal::CreateR3<int32>({{{7}}}),
+                             *ExecuteAndTransfer(std::move(hlo_module), {})));
 }
 
 XLA_TEST_F(FusionTest, Reshape__) {
@@ -394,8 +397,9 @@ XLA_TEST_F(FusionTest, Reshape__) {
   hlo_module->AddEntryComputation(builder.Build())
       ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
                                 HloInstruction::FusionKind::kLoop);
-  LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(7),
-                               *ExecuteAndTransfer(std::move(hlo_module), {}));
+  EXPECT_TRUE(
+      LiteralTestUtil::Equal(*Literal::CreateR0<int32>(7),
+                             *ExecuteAndTransfer(std::move(hlo_module), {})));
 }
 
 XLA_TEST_F(FusionTest, Reshape_3by3_3by3) {
@@ -408,9 +412,9 @@ XLA_TEST_F(FusionTest, Reshape_3by3_3by3) {
   hlo_module->AddEntryComputation(builder.Build())
       ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
                                 HloInstruction::FusionKind::kLoop);
-  LiteralTestUtil::ExpectEqual(
+  EXPECT_TRUE(LiteralTestUtil::Equal(
       *Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}),
-      *ExecuteAndTransfer(std::move(hlo_module), {}));
+      *ExecuteAndTransfer(std::move(hlo_module), {})));
 }
 
 XLA_TEST_F(FusionTest, Transpose_2by3) {
@@ -423,9 +427,9 @@ XLA_TEST_F(FusionTest, Transpose_2by3) {
   hlo_module->AddEntryComputation(builder.Build())
       ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
                                 HloInstruction::FusionKind::kLoop);
-  LiteralTestUtil::ExpectEqual(
+  EXPECT_TRUE(LiteralTestUtil::Equal(
       *Literal::CreateR2<int32>({{1, 4}, {2, 5}, {3, 6}}),
-      *ExecuteAndTransfer(std::move(hlo_module), {}));
+      *ExecuteAndTransfer(std::move(hlo_module), {})));
 }
 
 XLA_TEST_F(FusionTest, Transpose_3by3) {
@@ -438,9 +442,9 @@ XLA_TEST_F(FusionTest, Transpose_3by3) {
   hlo_module->AddEntryComputation(builder.Build())
       ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
                                 HloInstruction::FusionKind::kLoop);
-  LiteralTestUtil::ExpectEqual(
+  EXPECT_TRUE(LiteralTestUtil::Equal(
       *Literal::CreateR2<int32>({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}),
-      *ExecuteAndTransfer(std::move(hlo_module), {}));
+      *ExecuteAndTransfer(std::move(hlo_module), {})));
 }
 
 XLA_TEST_F(FusionTest, Reverse) {
@@ -454,8 +458,9 @@ XLA_TEST_F(FusionTest, Reverse) {
       ->CreateFusionInstruction(/*instructions_to_fuse=*/{reverse1},
                                 HloInstruction::FusionKind::kLoop);
 
-  LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({3, 2, 1}),
-                               *ExecuteAndTransfer(std::move(hlo_module), {}));
+  EXPECT_TRUE(
+      LiteralTestUtil::Equal(*Literal::CreateR1<int32>({3, 2, 1}),
+                             *ExecuteAndTransfer(std::move(hlo_module), {})));
 }
 
 XLA_TEST_F(FusionTest, ReverseNegate) {
@@ -471,8 +476,9 @@ XLA_TEST_F(FusionTest, ReverseNegate) {
       ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reverse1},
                                 HloInstruction::FusionKind::kLoop);
 
-  LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({-3, -2, -1}),
-                               *ExecuteAndTransfer(std::move(hlo_module), {}));
+  EXPECT_TRUE(
+      LiteralTestUtil::Equal(*Literal::CreateR1<int32>({-3, -2, -1}),
+                             *ExecuteAndTransfer(std::move(hlo_module), {})));
 }
 
 XLA_TEST_F(FusionTest, BroadcastNegate) {
@@ -488,8 +494,9 @@ XLA_TEST_F(FusionTest, BroadcastNegate) {
       ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, broadcast1},
                                 HloInstruction::FusionKind::kLoop);
 
-  LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({-1, -1}),
-                               *ExecuteAndTransfer(std::move(hlo_module), {}));
+  EXPECT_TRUE(
+      LiteralTestUtil::Equal(*Literal::CreateR1<int32>({-1, -1}),
+                             *ExecuteAndTransfer(std::move(hlo_module), {})));
 }
 
 XLA_TEST_F(FusionTest, SliceNegate) {
@@ -505,8 +512,9 @@ XLA_TEST_F(FusionTest, SliceNegate) {
       ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, slice1},
                                 HloInstruction::FusionKind::kLoop);
 
-  LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({-1, -3}),
-                               *ExecuteAndTransfer(std::move(hlo_module), {}));
+  EXPECT_TRUE(
+      LiteralTestUtil::Equal(*Literal::CreateR1<int32>({-1, -3}),
+                             *ExecuteAndTransfer(std::move(hlo_module), {})));
 }
 
 XLA_TEST_F(FusionTest, DynamicSliceNegate) {
@@ -526,8 +534,9 @@ XLA_TEST_F(FusionTest, DynamicSliceNegate) {
           /*instructions_to_fuse=*/{negate3, dynamic_slice2},
           HloInstruction::FusionKind::kLoop);
 
-  LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({-2, -3}),
-                               *ExecuteAndTransfer(std::move(hlo_module), {}));
+  EXPECT_TRUE(
+      LiteralTestUtil::Equal(*Literal::CreateR1<int32>({-2, -3}),
+                             *ExecuteAndTransfer(std::move(hlo_module), {})));
 }
 
 XLA_TEST_F(FusionTest, ReshapeNegate) {
@@ -543,8 +552,9 @@ XLA_TEST_F(FusionTest, ReshapeNegate) {
       ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reshape1},
                                 HloInstruction::FusionKind::kLoop);
 
-  LiteralTestUtil::ExpectEqual(*Literal::CreateR2<int32>({{-1, -2}, {-3, -4}}),
-                               *ExecuteAndTransfer(std::move(hlo_module), {}));
+  EXPECT_TRUE(
+      LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{-1, -2}, {-3, -4}}),
+                             *ExecuteAndTransfer(std::move(hlo_module), {})));
 }
 
 // TODO(b/64070202): Investigate failure.
@@ -561,8 +571,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_GPU(TransposeNegate)) {
       ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, transpose1},
                                 HloInstruction::FusionKind::kLoop);
 
-  LiteralTestUtil::ExpectEqual(*Literal::CreateR2<int32>({{-1, -3}, {-2, -4}}),
-                               *ExecuteAndTransfer(std::move(hlo_module), {}));
+  EXPECT_TRUE(
+      LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{-1, -3}, {-2, -4}}),
+                             *ExecuteAndTransfer(std::move(hlo_module), {})));
 }
 
 std::unique_ptr<HloComputation> MakeReduceTestComputation() {
@@ -591,8 +602,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) {
       ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce2},
                                 HloInstruction::FusionKind::kLoop);
 
-  LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(15),
-                               *ExecuteAndTransfer(std::move(hlo_module), {}));
+  EXPECT_TRUE(
+      LiteralTestUtil::Equal(*Literal::CreateR0<int32>(15),
+                             *ExecuteAndTransfer(std::move(hlo_module), {})));
 }
 
 XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) {
@@ -612,8 +624,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) {
       ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate3, reduce2},
                                 HloInstruction::FusionKind::kLoop);
 
-  LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(-15),
-                               *ExecuteAndTransfer(std::move(hlo_module), {}));
+  EXPECT_TRUE(
+      LiteralTestUtil::Equal(*Literal::CreateR0<int32>(-15),
+                             *ExecuteAndTransfer(std::move(hlo_module), {})));
 }
 
 XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) {
@@ -661,9 +674,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) {
       ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce_window2},
                                 HloInstruction::FusionKind::kLoop);
 
-  LiteralTestUtil::ExpectEqual(
+  EXPECT_TRUE(LiteralTestUtil::Equal(
       *Literal::CreateR2<int32>({{462, 2145}, {24871, 62491}}),
-      *ExecuteAndTransfer(std::move(hlo_module), {}));
+      *ExecuteAndTransfer(std::move(hlo_module), {})));
 }
 
 // When a constant (or other op) which has multiple users is imported
@@ -697,8 +710,9 @@ XLA_TEST_F(FusionTest, SharedConstant) {
   // fused instruction contains the constant(2), the parameter, and 4 adds
   EXPECT_EQ(entry_comp->root_instruction()->fused_instruction_count(), 6);
 
-  LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({8}),
-                               *ExecuteAndTransfer(std::move(hlo_module), {}));
+  EXPECT_TRUE(
+      LiteralTestUtil::Equal(*Literal::CreateR1<int32>({8}),
+                             *ExecuteAndTransfer(std::move(hlo_module), {})));
 }
 
 XLA_TEST_F(FusionTest, Add2D) { TestElementwise2D<float, 2>(HloOpcode::kAdd); }
index 130456e..4854c64 100644 (file)
@@ -629,8 +629,8 @@ XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) {
       client_->ExecuteParallel(computation_instances));
   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
                           client_->Transfer(*(result_data[0])));
-  LiteralTestUtil::ExpectEqual(
-      *result_literal, *Literal::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}}));
+  EXPECT_TRUE(LiteralTestUtil::Equal(
+      *result_literal, *Literal::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}})));
 }
 }  // namespace
 }  // namespace xla
index 868876c..c38a78d 100644 (file)
@@ -21,6 +21,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/index_util.h"
 #include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal_comparison.h"
 #include "tensorflow/compiler/xla/literal_util.h"
 #include "tensorflow/compiler/xla/ptr_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
@@ -46,117 +47,21 @@ using ::tensorflow::strings::StrCat;
 
 /* static */ ::testing::AssertionResult LiteralTestUtil::EqualShapes(
     const Shape& expected, const Shape& actual) {
-  if (ShapeUtil::IsTuple(expected) != ShapeUtil::IsTuple(actual)) {
-    return ::testing::AssertionFailure()
-           << "tupleness-mismatch! want: " << ShapeUtil::HumanString(expected)
-           << " got: " << ShapeUtil::HumanString(actual);
-  }
-  if (ShapeUtil::IsTuple(expected)) {
-    if (ShapeUtil::TupleElementCount(expected) !=
-        ShapeUtil::TupleElementCount(actual)) {
-      return ::testing::AssertionFailure()
-             << "want tuple element count: "
-             << ShapeUtil::TupleElementCount(expected)
-             << " got tuple element count: "
-             << ShapeUtil::TupleElementCount(actual);
-    }
-    for (int i = 0; i < expected.tuple_shapes_size(); ++i) {
-      ::testing::AssertionResult result =
-          EqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i))
-          << "mismatch in tuple index " << i;
-      if (!result) {
-        return result;
-      }
-    }
-  } else {
-    if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) {
-      return ::testing::AssertionFailure()
-             << "want rank of: " << ShapeUtil::HumanString(expected)
-             << " got rank of: " << ShapeUtil::HumanString(actual);
-    }
-    if (expected.element_type() != actual.element_type()) {
-      return ::testing::AssertionFailure()
-             << PrimitiveType_Name(expected.element_type()) << " vs "
-             << PrimitiveType_Name(actual.element_type());
-    }
-    if (expected.dimensions_size() != actual.dimensions_size()) {
-      return ::testing::AssertionFailure()
-             << "want dimensions_size " << expected.dimensions_size()
-             << " got dimensions_size " << actual.dimensions_size();
-    }
-    for (int i = 0; i < expected.dimensions_size(); ++i) {
-      if (expected.dimensions(i) != actual.dimensions(i)) {
-        return ::testing::AssertionFailure()
-               << "mismatch in dimension #" << i
-               << " expected: " << ShapeUtil::HumanString(expected)
-               << " actual: " << ShapeUtil::HumanString(actual);
-      }
-    }
+  Status result = literal_comparison::EqualShapes(expected, actual);
+  if (result.ok()) {
+    return ::testing::AssertionSuccess();
   }
-  return ::testing::AssertionSuccess();
-}
-
-/* static */ void LiteralTestUtil::AssertEqualShapes(const Shape& expected,
-                                                     const Shape& actual) {
-  ASSERT_TRUE(EqualShapes(expected, actual));
+  return ::testing::AssertionFailure() << result;
 }
 
-/* static */ void LiteralTestUtil::AssertEqualShapesAndLayouts(
+/* static */ ::testing::AssertionResult LiteralTestUtil::EqualShapesAndLayouts(
     const Shape& expected, const Shape& actual) {
-  ASSERT_EQ(expected.ShortDebugString(), actual.ShortDebugString());
-}
-
-namespace {
-
-// Return a literal with all arrays of type FromNativeT converted to type
-// ToNativeT in the given literal.
-template <typename FromNativeT, typename ToNativeT>
-std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
-  // First construct shape of the result.
-  Shape result_shape(literal.shape());
-  ShapeUtil::ForEachMutableSubshape(
-      &result_shape, [](Shape* subshape, const ShapeIndex&) {
-        if (subshape->element_type() ==
-            primitive_util::NativeToPrimitiveType<FromNativeT>()) {
-          subshape->set_element_type(
-              primitive_util::NativeToPrimitiveType<ToNativeT>());
-        }
-      });
-  auto result = MakeUnique<Literal>(result_shape);
-
-  // Then copy over the data from 'literal' converting FromNativeT values to
-  // ToNativeT values as necessary.
-  ShapeUtil::ForEachSubshape(
-      literal.shape(),
-      [&](const Shape& subshape, const ShapeIndex& shape_index) {
-        if (ShapeUtil::IsArray(subshape)) {
-          if (subshape.element_type() ==
-              primitive_util::NativeToPrimitiveType<FromNativeT>()) {
-            auto src = literal.data<FromNativeT>(shape_index);
-            auto dest = result->data<ToNativeT>(shape_index);
-            for (int64 i = 0; i < src.size(); ++i) {
-              dest[i] = static_cast<ToNativeT>(src[i]);
-            }
-          } else {
-            TF_CHECK_OK(result->CopyFrom(literal,
-                                         /*dest_shape_index=*/shape_index,
-                                         /*src_shape_index=*/shape_index));
-          }
-        }
-      });
-  return result;
-}
-
-}  // namespace
-
-/* static */ std::unique_ptr<Literal> LiteralTestUtil::ConvertBF16ToF32(
-    LiteralSlice literal) {
-  return ConvertType<bfloat16, float>(literal);
-}
-
-/* static */ std::unique_ptr<Literal> LiteralTestUtil::ConvertF32ToBF16(
-    LiteralSlice literal) {
-  return ConvertType<float, bfloat16>(literal);
+  if (expected.ShortDebugString() != actual.ShortDebugString()) {
+    return ::testing::AssertionFailure()
+           << "want: " << expected.ShortDebugString()
+           << " got: " << actual.ShortDebugString();
+  }
+  return ::testing::AssertionSuccess();
 }
 
 namespace {
@@ -168,183 +73,15 @@ string Hostname() {
   return string(hostname);
 }
 
-// Helper function for comparing a floating point type, FloatT, bitwise equal
-// between the left-hand-side and right-hand-side, by bit-casting to UnsignedT
-// -- on miscompare, a nice error message is given in the AssertionFailure.
-template <typename FloatT, typename UnsignedT>
-::testing::AssertionResult CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) {
-  auto ulhs = tensorflow::bit_cast<UnsignedT>(lhs);
-  auto urhs = tensorflow::bit_cast<UnsignedT>(rhs);
-  auto lhs_double = static_cast<double>(lhs);
-  auto rhs_double = static_cast<double>(rhs);
-  if (ulhs != urhs) {
-    return ::testing::AssertionFailure() << Printf(
-               "floating values are not bitwise-equal; and equality testing "
-               "was requested: %s=%g=%a vs %s=%g=%a",
-               StrCat(tensorflow::strings::Hex(ulhs)).c_str(), lhs_double,
-               lhs_double, StrCat(tensorflow::strings::Hex(urhs)).c_str(),
-               rhs_double, rhs_double);
-  }
-  return ::testing::AssertionSuccess();
-}
-
-// Templated comparator that specializes for float equality comparison with the
-// bitwise helper above (this is the un-specialized fallback, to just use the
-// default gunit implementation).
-template <typename NativeT>
-::testing::AssertionResult CompareEqual(NativeT lhs, NativeT rhs) {
-  if (lhs == rhs) {
-    return ::testing::AssertionSuccess();
-  }
-  ::testing::Message msg;
-  msg << "Expected equality of these values:";
-  msg << "\n  " << lhs;
-  msg << "\n  " << rhs;
-
-  return ::testing::AssertionFailure() << msg;
-}
-
-// Specializations for floating types that do bitwise comparisons when equality
-// comparison is requested.
-template <>
-::testing::AssertionResult CompareEqual<bfloat16>(bfloat16 lhs, bfloat16 rhs) {
-  return CompareFloatsBitwiseEqual<bfloat16, uint16>(lhs, rhs);
-}
-template <>
-::testing::AssertionResult CompareEqual<Eigen::half>(Eigen::half lhs,
-                                                     Eigen::half rhs) {
-  return CompareFloatsBitwiseEqual<Eigen::half, uint16>(lhs, rhs);
-}
-template <>
-::testing::AssertionResult CompareEqual<float>(float lhs, float rhs) {
-  return CompareFloatsBitwiseEqual<float, uint32>(lhs, rhs);
-}
-template <>
-::testing::AssertionResult CompareEqual<double>(double lhs, double rhs) {
-  return CompareFloatsBitwiseEqual<double, uint64>(lhs, rhs);
-}
-template <>
-::testing::AssertionResult CompareEqual<complex64>(complex64 lhs,
-                                                   complex64 rhs) {
-  auto res = CompareEqual<float>(lhs.real(), rhs.real());
-  if (!res) {
-    return res;
-  }
-  return CompareEqual<float>(lhs.imag(), rhs.imag());
-}
-
-// A recursive function which iterates through every index of expected and
-// actual literal and compares their values elementwise. Returns true if all
-// elements are equal.
-template <typename NativeT>
-bool ExpectLiteralsEqual(LiteralSlice expected, LiteralSlice actual,
-                         tensorflow::gtl::MutableArraySlice<int64> multi_index,
-                         int64 dimension) {
-  if (dimension == expected.shape().dimensions_size()) {
-    NativeT expected_value = expected.Get<NativeT>(multi_index);
-    NativeT actual_value = actual.Get<NativeT>(multi_index);
-    ::testing::AssertionResult result =
-        CompareEqual<NativeT>(expected_value, actual_value);
-    return result;  // Defines implicit coersion to bool.
-  }
-
-  bool all_match = true;
-  for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) {
-    multi_index[dimension] = i;
-    all_match = all_match && ExpectLiteralsEqual<NativeT>(
-                                 expected, actual, multi_index, dimension + 1);
-  }
-  return all_match;
-}
-
 }  // namespace
 
-/* static */ void LiteralTestUtil::ExpectEqual(LiteralSlice expected,
-                                               LiteralSlice actual,
-                                               const string& message) {
-  EXPECT_TRUE(Equal(expected, actual))
-      << "expected:\n"
-      << expected.ToString() << "\n\tvs actual:\n"
-      << actual.ToString()
-      << (message.empty() ? "" : StrCat("\nmessage: ", message));
-}
-
-/* static */ void LiteralTestUtil::ExpectNotEqual(LiteralSlice expected,
-                                                  LiteralSlice actual) {
-  EXPECT_FALSE(Equal(expected, actual));
-}
-
 /* static */ ::testing::AssertionResult LiteralTestUtil::Equal(
-    LiteralSlice expected, LiteralSlice actual) {
-  VLOG(1) << "expected:";
-  XLA_VLOG_LINES(1, expected.ToString());
-  VLOG(1) << "actual:";
-  XLA_VLOG_LINES(1, actual.ToString());
-
-  AssertEqualShapes(expected.shape(), actual.shape());
-  std::vector<int64> multi_index(expected.shape().dimensions_size(), 0);
-  bool match = false;
-  switch (expected.shape().element_type()) {
-    case PRED:
-      match = ExpectLiteralsEqual<bool>(expected, actual, &multi_index, 0);
-      break;
-    case U8:
-      match = ExpectLiteralsEqual<uint8>(expected, actual, &multi_index, 0);
-      break;
-    case S32:
-      match = ExpectLiteralsEqual<int32>(expected, actual, &multi_index, 0);
-      break;
-    case S64:
-      match = ExpectLiteralsEqual<int64>(expected, actual, &multi_index, 0);
-      break;
-    case U32:
-      match = ExpectLiteralsEqual<uint32>(expected, actual, &multi_index, 0);
-      break;
-    case U64:
-      match = ExpectLiteralsEqual<uint64>(expected, actual, &multi_index, 0);
-      break;
-    case BF16:
-      match = ExpectLiteralsEqual<bfloat16>(expected, actual, &multi_index, 0);
-      break;
-    case F16:
-      match = ExpectLiteralsEqual<half>(expected, actual, &multi_index, 0);
-      break;
-    case F32:
-      match = ExpectLiteralsEqual<float>(expected, actual, &multi_index, 0);
-      break;
-    case F64:
-      match = ExpectLiteralsEqual<double>(expected, actual, &multi_index, 0);
-      break;
-    case C64:
-      match = ExpectLiteralsEqual<complex64>(expected, actual, &multi_index, 0);
-      break;
-    case TUPLE: {
-      bool tuple_match = true;
-      for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
-        SCOPED_TRACE(StrCat("Tuple index ", i, " in ",
-                            ShapeUtil::HumanString(expected.shape())));
-
-        // Create LiteralSlices of the expected and actual elements.
-        auto result =
-            Equal(LiteralSlice(expected, {i}), LiteralSlice(actual, {i}));
-        tuple_match = tuple_match ? !!result : false;
-      }
-      match = tuple_match;
-      break;
-    }
-    default:
-      LOG(FATAL)
-          << "Unsupported primitive type in LiteralTestUtil::ExpectEqual: "
-          << PrimitiveType_Name(expected.shape().element_type());
-  }
-  ::testing::AssertionResult result = ::testing::AssertionSuccess();
-  if (!match) {
-    result = ::testing::AssertionFailure()
-             << "expected: " << expected.ToString()
-             << "\nactual:   " << actual.ToString();
-    VLOG(1) << result.message();
+    const LiteralSlice& expected, const LiteralSlice& actual) {
+  Status result = literal_comparison::Equal(expected, actual);
+  if (result.ok()) {
+    return ::testing::AssertionSuccess();
   }
-  return result;
+  return ::testing::AssertionFailure() << result;
 }
 
 namespace {
@@ -368,7 +105,7 @@ int64 RecursiveElementCount(const Shape& shape) {
 // 3 minutes.  The utility of printing a literal with >1000 elements is
 // questionable, especially when writing the Literal proto to disk is orders
 // of magnitude faster.
-string TruncateHugeLiteral(LiteralSlice literal) {
+string TruncateHugeLiteral(const LiteralSlice& literal) {
   return RecursiveElementCount(literal.shape()) < 1000
              ? literal.ToString()
              : "[TRUNCATED, Literal with more than 1000 values]";
@@ -435,8 +172,8 @@ class NearComparator {
   // result. The assertion result is successful if all actual and expected
   // elements are within the given error bound. In case of error, the assertion
   // result contains a detailed error message in case of failure.
-  static ::testing::AssertionResult Compare(LiteralSlice expected,
-                                            LiteralSlice actual,
+  static ::testing::AssertionResult Compare(const LiteralSlice& expected,
+                                            const LiteralSlice& actual,
                                             ErrorSpec error,
                                             bool detailed_message) {
     NearComparator<NativeT> comparator(expected, actual, error,
@@ -464,7 +201,7 @@ class NearComparator {
       return Printf(
           "actual %s, expected %s, index %s, rel error %8.3g, abs error %8.3g",
           FpValueToString(actual).c_str(), FpValueToString(expected).c_str(),
-          LiteralTestUtil::MultiIndexAsString(
+          Literal::MultiIndexAsString(
               IndexUtil::LinearIndexToMultidimensionalIndex(shape,
                                                             linear_index))
               .c_str(),
@@ -472,8 +209,9 @@ class NearComparator {
     }
   };
 
-  explicit NearComparator(LiteralSlice expected, LiteralSlice actual,
-                          ErrorSpec error, bool detailed_message)
+  explicit NearComparator(const LiteralSlice& expected,
+                          const LiteralSlice& actual, ErrorSpec error,
+                          bool detailed_message)
       : expected_(expected),
         actual_(actual),
         error_(error),
@@ -649,7 +387,7 @@ class NearComparator {
   }
 
   // Writes the given literal to a file in the test temporary directory.
-  void WriteLiteralToTempFile(LiteralSlice literal, const string& name) {
+  void WriteLiteralToTempFile(const LiteralSlice& literal, const string& name) {
     int64 now_usec = tensorflow::Env::Default()->NowMicros();
     string filename = tensorflow::io::JoinPath(
         tensorflow::testing::TmpDir(),
@@ -794,8 +532,8 @@ constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds;
 // Helper function for comparing two literals for nearness. Handles tuple-shapes
 // via recursion. shape_index is the ShapeIndex of expected (or actual)
 // currently being compared.
-::testing::AssertionResult NearHelper(LiteralSlice expected,
-                                      LiteralSlice actual,
+::testing::AssertionResult NearHelper(const LiteralSlice& expected,
+                                      const LiteralSlice& actual,
                                       const ErrorSpec& error,
                                       bool detailed_message,
                                       const ShapeIndex& shape_index) {
@@ -874,30 +612,14 @@ constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds;
 }  // namespace
 
 /* static */ ::testing::AssertionResult LiteralTestUtil::Near(
-    LiteralSlice expected, LiteralSlice actual, const ErrorSpec& error,
-    bool detailed_message) {
+    const LiteralSlice& expected, const LiteralSlice& actual,
+    const ErrorSpec& error, bool detailed_message) {
   return NearHelper(expected, actual, error, detailed_message,
                     /*shape_index=*/{});
 }
 
-/* static */ void LiteralTestUtil::ExpectNear(LiteralSlice expected,
-                                              LiteralSlice actual,
-                                              const ErrorSpec& error,
-                                              const string& message) {
-  ::testing::AssertionResult res =
-      Near(expected, actual, error, /*detailed_message=*/false);
-  if (!res) {
-    res << "Expected: " << TruncateHugeLiteral(expected) << "\n";
-    res << "Actual: " << TruncateHugeLiteral(actual) << "\n";
-    if (!message.empty()) {
-      res << StrCat("\nmessage: ", message);
-    }
-  }
-  EXPECT_TRUE(res);
-}
-
-/*static*/ ::testing::AssertionResult LiteralTestUtil::NearOrEqual(
-    LiteralSlice expected, LiteralSlice actual,
+/* static */ ::testing::AssertionResult LiteralTestUtil::NearOrEqual(
+    const LiteralSlice& expected, const LiteralSlice& actual,
     const tensorflow::gtl::optional<ErrorSpec>& error) {
   if (error.has_value()) {
     VLOG(1) << "Expects near";
@@ -907,86 +629,4 @@ constexpr std::array<float, 5> NearComparator<NativeT>::kErrorBucketBounds;
   return Equal(expected, actual);
 }
 
-/*static*/ void LiteralTestUtil::ExpectNearOrEqual(
-    LiteralSlice expected, LiteralSlice actual,
-    const tensorflow::gtl::optional<ErrorSpec>& error) {
-  EXPECT_TRUE(NearOrEqual(expected, actual, error));
-}
-
-/* static */ string LiteralTestUtil::MultiIndexAsString(
-    tensorflow::gtl::ArraySlice<int64> multi_index) {
-  return StrCat("{", tensorflow::str_util::Join(multi_index, ","), "}");
-}
-
-/* static */ std::unique_ptr<Literal> LiteralTestUtil::Reshape(
-    tensorflow::gtl::ArraySlice<int64> new_dimensions,
-    tensorflow::gtl::ArraySlice<int64> minor_to_major, LiteralSlice literal) {
-  int64 new_num_elements = 1;
-  for (int64 i = 0; i < new_dimensions.size(); ++i) {
-    new_num_elements *= new_dimensions[i];
-  }
-  CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements);
-  CHECK_EQ(new_dimensions.size(), minor_to_major.size());
-
-  auto new_literal = MakeUnique<Literal>(
-      ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions));
-
-  // Create a new shape with the given minor-to-major layout. This shape is used
-  // solely for converting linear address to multi-dimensional addresses when
-  // writing elements to the new literal.
-  Shape shape_with_layout = new_literal->shape();
-  *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
-
-  // Copy data into new literal, element-by-element.
-  for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) {
-    std::vector<int64> from_multi_index =
-        IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i);
-    std::vector<int64> to_multi_index =
-        IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i);
-    switch (literal.shape().element_type()) {
-      case PRED:
-        new_literal->Set<bool>(to_multi_index,
-                               literal.Get<bool>(from_multi_index));
-        break;
-      case U8:
-        new_literal->Set<uint8>(to_multi_index,
-                                literal.Get<uint8>(from_multi_index));
-        break;
-      case U32:
-        new_literal->Set<uint32>(to_multi_index,
-                                 literal.Get<uint32>(from_multi_index));
-        break;
-      case S32:
-        new_literal->Set<int32>(to_multi_index,
-                                literal.Get<int32>(from_multi_index));
-        break;
-      case U64:
-        new_literal->Set<uint64>(to_multi_index,
-                                 literal.Get<uint64>(from_multi_index));
-        break;
-      case S64:
-        new_literal->Set<int64>(to_multi_index,
-                                literal.Get<int64>(from_multi_index));
-        break;
-      case F32:
-        new_literal->Set<float>(to_multi_index,
-                                literal.Get<float>(from_multi_index));
-        break;
-      case F64:
-        new_literal->Set<double>(to_multi_index,
-                                 literal.Get<double>(from_multi_index));
-        break;
-      case C64:
-        new_literal->Set<complex64>(to_multi_index,
-                                    literal.Get<complex64>(from_multi_index));
-        break;
-      default:
-        LOG(FATAL) << "Unhandled primitive element type: "
-                   << PrimitiveType_Name(literal.shape().element_type());
-    }
-  }
-
-  return new_literal;
-}
-
 }  // namespace xla
index 4983ddd..c9cb851 100644 (file)
@@ -57,65 +57,47 @@ class LiteralTestUtil {
  public:
   // Asserts that the given shapes have the same rank, dimension sizes, and
   // primitive types.
-  static ::testing::AssertionResult EqualShapes(const Shape& expected,
-                                                const Shape& actual);
-  static void AssertEqualShapes(const Shape& expected, const Shape& actual);
+  static ::testing::AssertionResult EqualShapes(
+      const Shape& expected, const Shape& actual) MUST_USE_RESULT;
 
   // Asserts that the provided shapes are equal as defined in AssertEqualShapes
   // and that they have the same layout.
-  static void AssertEqualShapesAndLayouts(const Shape& expected,
-                                          const Shape& actual);
+  static ::testing::AssertionResult EqualShapesAndLayouts(
+      const Shape& expected, const Shape& actual) MUST_USE_RESULT;
 
-  // If the given literal's data type is bfloat16, converts it to a float
-  // literal; otherwise, returns a copy of it. If the literal is a tuple,
-  // recursively converts its elements.
-  static std::unique_ptr<Literal> ConvertBF16ToF32(LiteralSlice bf16_literal);
-
-  // If the given literal's data type is float, converts it to a bfloat16
-  // literal; otherwise, returns a copy of it. If the literal is a tuple,
-  // recursively converts its elements.
-  static std::unique_ptr<Literal> ConvertF32ToBF16(LiteralSlice f32_literal);
-
-  // Asserts that the expected and actual literals are (bitwise) equal for all
-  // elements in the literal. Also, asserts that the rank, dimensions sizes, and
-  // primitive type are equal.
-  static ::testing::AssertionResult Equal(
-      LiteralSlice expected, LiteralSlice actual) TF_MUST_USE_RESULT;
-
-  // Expects that expected and actual are Equal.
-  static void ExpectEqual(LiteralSlice expected, LiteralSlice actual,
-                          const string& message = "");
-
-  // Expects that expected and actual are Not Equal.
-  static void ExpectNotEqual(LiteralSlice expected, LiteralSlice actual);
+  static ::testing::AssertionResult Equal(const LiteralSlice& expected,
+                                          const LiteralSlice& actual)
+      TF_MUST_USE_RESULT;
 
   // Asserts the given literal are (bitwise) equal to given expected values.
   template <typename NativeT>
-  static void ExpectR0Equal(NativeT expected, LiteralSlice actual);
+  static void ExpectR0Equal(NativeT expected, const LiteralSlice& actual);
+
   template <typename NativeT>
   static void ExpectR1Equal(tensorflow::gtl::ArraySlice<NativeT> expected,
-                            LiteralSlice actual);
+                            const LiteralSlice& actual);
   template <typename NativeT>
   static void ExpectR2Equal(
       std::initializer_list<std::initializer_list<NativeT>> expected,
-      LiteralSlice actual);
+      const LiteralSlice& actual);
+
   template <typename NativeT>
   static void ExpectR3Equal(
       std::initializer_list<
           std::initializer_list<std::initializer_list<NativeT>>>
           expected,
-      LiteralSlice actual);
+      const LiteralSlice& actual);
 
   // Asserts the given literal are (bitwise) equal to given array.
   template <typename NativeT>
   static void ExpectR2EqualArray2D(const Array2D<NativeT>& expected,
-                                   LiteralSlice actual);
+                                   const LiteralSlice& actual);
   template <typename NativeT>
   static void ExpectR3EqualArray3D(const Array3D<NativeT>& expected,
-                                   LiteralSlice actual);
+                                   const LiteralSlice& actual);
   template <typename NativeT>
   static void ExpectR4EqualArray4D(const Array4D<NativeT>& expected,
-                                   LiteralSlice actual);
+                                   const LiteralSlice& actual);
 
   // Asserts that the expected and actual literals are within the given error
   // bound for all elements. Also, asserts that the rank, dimensions sizes, and
@@ -133,183 +115,138 @@ class LiteralTestUtil {
   // If detailed_message is true, then the error message in the assertion result
   // will contain a more detailed breakdown of mismatches.
   static ::testing::AssertionResult Near(
-      LiteralSlice expected, LiteralSlice actual, const ErrorSpec& error,
-      bool detailed_message = false) TF_MUST_USE_RESULT;
-
-  // Expects expected and actual to be Near with the given error.
-  static void ExpectNear(LiteralSlice expected, LiteralSlice actual,
-                         const ErrorSpec& error, const string& message = "");
+      const LiteralSlice& expected, const LiteralSlice& actual,
+      const ErrorSpec& error, bool detailed_message = false) TF_MUST_USE_RESULT;
 
   // Asserts the given literal are within the given error bound of the given
   // expected values. Only supported for floating point values.
   template <typename NativeT>
-  static void ExpectR0Near(NativeT expected, LiteralSlice actual,
+  static void ExpectR0Near(NativeT expected, const LiteralSlice& actual,
                            const ErrorSpec& error);
+
   template <typename NativeT>
   static void ExpectR1Near(tensorflow::gtl::ArraySlice<NativeT> expected,
-                           LiteralSlice actual, const ErrorSpec& error);
+                           const LiteralSlice& actual, const ErrorSpec& error);
+
   template <typename NativeT>
   static void ExpectR2Near(
       std::initializer_list<std::initializer_list<NativeT>> expected,
-      LiteralSlice actual, const ErrorSpec& error);
+      const LiteralSlice& actual, const ErrorSpec& error);
+
   template <typename NativeT>
   static void ExpectR3Near(
       std::initializer_list<
           std::initializer_list<std::initializer_list<NativeT>>>
           expected,
-      LiteralSlice actual, const ErrorSpec& error);
+      const LiteralSlice& actual, const ErrorSpec& error);
+
   template <typename NativeT>
   static void ExpectR4Near(
       std::initializer_list<std::initializer_list<
           std::initializer_list<std::initializer_list<NativeT>>>>
           expected,
-      LiteralSlice actual, const ErrorSpec& error);
+      const LiteralSlice& actual, const ErrorSpec& error);
 
   // Asserts the given literal are within the given error bound to the given
   // array. Only supported for floating point values.
   template <typename NativeT>
   static void ExpectR2NearArray2D(const Array2D<NativeT>& expected,
-                                  LiteralSlice actual, const ErrorSpec& error);
+                                  const LiteralSlice& actual,
+                                  const ErrorSpec& error);
+
   template <typename NativeT>
   static void ExpectR3NearArray3D(const Array3D<NativeT>& expected,
-                                  LiteralSlice actual, const ErrorSpec& error);
+                                  const LiteralSlice& actual,
+                                  const ErrorSpec& error);
+
   template <typename NativeT>
   static void ExpectR4NearArray4D(const Array4D<NativeT>& expected,
-                                  LiteralSlice actual, const ErrorSpec& error);
+                                  const LiteralSlice& actual,
+                                  const ErrorSpec& error);
 
   // If the error spec is given, returns whether the expected and the actual are
   // within the error bound; otherwise, returns whether they are equal. Tuples
   // will be compared recursively.
   static ::testing::AssertionResult NearOrEqual(
-      LiteralSlice expected, LiteralSlice actual,
+      const LiteralSlice& expected, const LiteralSlice& actual,
       const tensorflow::gtl::optional<ErrorSpec>& error) TF_MUST_USE_RESULT;
 
-  // If the error spec is given, expects the expected and the actual to be near;
-  // otherwise, expects them to be equal. Tuples will be compared recursively.
-  static void ExpectNearOrEqual(
-      LiteralSlice expected, LiteralSlice actual,
-      const tensorflow::gtl::optional<ErrorSpec>& error);
-
-  // Returns a multi-dimensional index as a string. For example: '{7, 8}' will
-  // be returned for a 2-dimensional index with dimension 0 index equal to 7,
-  // dimension 1 equal to 8.
-  static string MultiIndexAsString(
-      tensorflow::gtl::ArraySlice<int64> multi_index);
-
-  // Creates a literal with a new shape with the given new dimensions using the
-  // data in the given input literal. For reshaping purposes the (flat) data
-  // buffer of the input literal is assumed to have the given minor_to_major
-  // layout order.
-  static std::unique_ptr<Literal> Reshape(
-      tensorflow::gtl::ArraySlice<int64> new_dimensions,
-      tensorflow::gtl::ArraySlice<int64> minor_to_major, LiteralSlice literal);
-
-  // Creates a literal with the supplied shape, and uses the provided value
-  // generator to populate the literal's values.
-  // Returns the new literal object, or an error Status if failed.
-  template <
-      PrimitiveType type,
-      typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
-  static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
-      const Shape& shape,
-      const std::function<T(tensorflow::gtl::ArraySlice<int64>)>& generator);
-
-  // Creates a literal with the supplied shape, and initializes the literal
-  // values using a normal distribution with given mean and stddev standard
-  // deviation, and using the engine as entropy generator.
-  // Returns the new literal object, or an error Status if failed.
-  template <
-      PrimitiveType type, typename E,
-      typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
-  static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
-      const Shape& shape, E* engine, T mean, T stddev);
-
-  // Creates a literal with the supplied shape, and initializes the literal
-  // values using a normal distribution with given mean and stddev standard
-  // deviation.
-  // Returns the new literal object, or an error Status if failed.
-  template <
-      PrimitiveType type,
-      typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
-  static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
-      const Shape& shape, T mean, T stddev);
-
  private:
   TF_DISALLOW_COPY_AND_ASSIGN(LiteralTestUtil);
 };
 
 template <typename NativeT>
 /* static */ void LiteralTestUtil::ExpectR0Equal(NativeT expected,
-                                                 LiteralSlice actual) {
-  ExpectEqual(*Literal::CreateR0<NativeT>(expected), actual);
+                                                 const LiteralSlice& actual) {
+  EXPECT_TRUE(Equal(*Literal::CreateR0<NativeT>(expected), actual));
 }
 
 template <typename NativeT>
 /* static */ void LiteralTestUtil::ExpectR1Equal(
-    tensorflow::gtl::ArraySlice<NativeT> expected, LiteralSlice actual) {
-  ExpectEqual(*Literal::CreateR1<NativeT>(expected), actual);
+    tensorflow::gtl::ArraySlice<NativeT> expected, const LiteralSlice& actual) {
+  EXPECT_TRUE(Equal(*Literal::CreateR1<NativeT>(expected), actual));
 }
 
 template <typename NativeT>
 /* static */ void LiteralTestUtil::ExpectR2Equal(
     std::initializer_list<std::initializer_list<NativeT>> expected,
-    LiteralSlice actual) {
-  ExpectEqual(*Literal::CreateR2<NativeT>(expected), actual);
+    const LiteralSlice& actual) {
+  EXPECT_TRUE(Equal(*Literal::CreateR2<NativeT>(expected), actual));
 }
 
 template <typename NativeT>
 /* static */ void LiteralTestUtil::ExpectR3Equal(
     std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
         expected,
-    LiteralSlice actual) {
-  ExpectEqual(*Literal::CreateR3<NativeT>(expected), actual);
+    const LiteralSlice& actual) {
+  EXPECT_TRUE(Equal(*Literal::CreateR3<NativeT>(expected), actual));
 }
 
 template <typename NativeT>
 /* static */ void LiteralTestUtil::ExpectR2EqualArray2D(
-    const Array2D<NativeT>& expected, LiteralSlice actual) {
-  ExpectEqual(*Literal::CreateR2FromArray2D(expected), actual);
+    const Array2D<NativeT>& expected, const LiteralSlice& actual) {
+  EXPECT_TRUE(Equal(*Literal::CreateR2FromArray2D(expected), actual));
 }
 
 template <typename NativeT>
 /* static */ void LiteralTestUtil::ExpectR3EqualArray3D(
-    const Array3D<NativeT>& expected, LiteralSlice actual) {
-  ExpectEqual(*Literal::CreateR3FromArray3D(expected), actual);
+    const Array3D<NativeT>& expected, const LiteralSlice& actual) {
+  EXPECT_TRUE(Equal(*Literal::CreateR3FromArray3D(expected), actual));
 }
 
 template <typename NativeT>
 /* static */ void LiteralTestUtil::ExpectR4EqualArray4D(
-    const Array4D<NativeT>& expected, LiteralSlice actual) {
-  ExpectEqual(*Literal::CreateR4FromArray4D(expected), actual);
+    const Array4D<NativeT>& expected, const LiteralSlice& actual) {
+  EXPECT_TRUE(Equal(*Literal::CreateR4FromArray4D(expected), actual));
 }
 
 template <typename NativeT>
 /* static */ void LiteralTestUtil::ExpectR0Near(NativeT expected,
-                                                LiteralSlice actual,
+                                                const LiteralSlice& actual,
                                                 const ErrorSpec& error) {
-  ExpectNear(*Literal::CreateR0<NativeT>(expected), actual, error);
+  EXPECT_TRUE(Near(*Literal::CreateR0<NativeT>(expected), actual, error));
 }
 
 template <typename NativeT>
 /* static */ void LiteralTestUtil::ExpectR1Near(
-    tensorflow::gtl::ArraySlice<NativeT> expected, LiteralSlice actual,
+    tensorflow::gtl::ArraySlice<NativeT> expected, const LiteralSlice& actual,
     const ErrorSpec& error) {
-  ExpectNear(*Literal::CreateR1<NativeT>(expected), actual, error);
+  EXPECT_TRUE(Near(*Literal::CreateR1<NativeT>(expected), actual, error));
 }
 
 template <typename NativeT>
 /* static */ void LiteralTestUtil::ExpectR2Near(
     std::initializer_list<std::initializer_list<NativeT>> expected,
-    LiteralSlice actual, const ErrorSpec& error) {
-  ExpectNear(*Literal::CreateR2<NativeT>(expected), actual, error);
+    const LiteralSlice& actual, const ErrorSpec& error) {
+  EXPECT_TRUE(Near(*Literal::CreateR2<NativeT>(expected), actual, error));
 }
 
 template <typename NativeT>
 /* static */ void LiteralTestUtil::ExpectR3Near(
     std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
         expected,
-    LiteralSlice actual, const ErrorSpec& error) {
-  ExpectNear(*Literal::CreateR3<NativeT>(expected), actual, error);
+    const LiteralSlice& actual, const ErrorSpec& error) {
+  EXPECT_TRUE(Near(*Literal::CreateR3<NativeT>(expected), actual, error));
 }
 
 template <typename NativeT>
@@ -317,63 +254,29 @@ template <typename NativeT>
     std::initializer_list<std::initializer_list<
         std::initializer_list<std::initializer_list<NativeT>>>>
         expected,
-    LiteralSlice actual, const ErrorSpec& error) {
-  ExpectNear(*Literal::CreateR4<NativeT>(expected), actual, error);
+    const LiteralSlice& actual, const ErrorSpec& error) {
+  EXPECT_TRUE(Near(*Literal::CreateR4<NativeT>(expected), actual, error));
 }
 
 template <typename NativeT>
 /* static */ void LiteralTestUtil::ExpectR2NearArray2D(
-    const Array2D<NativeT>& expected, LiteralSlice actual,
+    const Array2D<NativeT>& expected, const LiteralSlice& actual,
     const ErrorSpec& error) {
-  ExpectNear(*Literal::CreateR2FromArray2D(expected), actual, error);
+  EXPECT_TRUE(Near(*Literal::CreateR2FromArray2D(expected), actual, error));
 }
 
 template <typename NativeT>
 /* static */ void LiteralTestUtil::ExpectR3NearArray3D(
-    const Array3D<NativeT>& expected, LiteralSlice actual,
+    const Array3D<NativeT>& expected, const LiteralSlice& actual,
     const ErrorSpec& error) {
-  ExpectNear(*Literal::CreateR3FromArray3D(expected), actual, error);
+  EXPECT_TRUE(Near(*Literal::CreateR3FromArray3D(expected), actual, error));
 }
 
 template <typename NativeT>
 /* static */ void LiteralTestUtil::ExpectR4NearArray4D(
-    const Array4D<NativeT>& expected, LiteralSlice actual,
+    const Array4D<NativeT>& expected, const LiteralSlice& actual,
     const ErrorSpec& error) {
-  ExpectNear(*Literal::CreateR4FromArray4D(expected), actual, error);
-}
-
-template <PrimitiveType type, typename T>
-/* static */ StatusOr<std::unique_ptr<Literal>>
-LiteralTestUtil::CreateRandomLiteral(
-    const Shape& shape,
-    const std::function<T(tensorflow::gtl::ArraySlice<int64>)>& generator) {
-  using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
-  TF_RET_CHECK(shape.element_type() == type);
-  std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape);
-  TF_RETURN_IF_ERROR(literal.get()->Populate<NativeT>(
-      [&](tensorflow::gtl::ArraySlice<int64> indexes) {
-        return generator(indexes);
-      }));
-  return std::move(literal);
-}
-
-template <PrimitiveType type, typename E, typename T>
-/* static */ StatusOr<std::unique_ptr<Literal>>
-LiteralTestUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean,
-                                     T stddev) {
-  using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
-  std::normal_distribution<NativeT> generator(mean, stddev);
-  return CreateRandomLiteral<type, NativeT>(
-      shape, [&](tensorflow::gtl::ArraySlice<int64> /*indexes*/) {
-        return generator(*engine);
-      });
-}
-
-template <PrimitiveType type, typename T>
-/* static */ StatusOr<std::unique_ptr<Literal>>
-LiteralTestUtil::CreateRandomLiteral(const Shape& shape, T mean, T stddev) {
-  std::minstd_rand0 engine;
-  return CreateRandomLiteral<type>(shape, &engine, mean, stddev);
+  EXPECT_TRUE(Near(*Literal::CreateR4FromArray4D(expected), actual, error));
 }
 
 }  // namespace xla
index 9d619a7..bbac728 100644 (file)
@@ -34,7 +34,7 @@ TEST(LiteralTestUtilTest, ComparesEqualTuplesEqual) {
   std::unique_ptr<Literal> literal = Literal::MakeTuple({
       Literal::CreateR0<int32>(42).get(), Literal::CreateR0<int32>(64).get(),
   });
-  LiteralTestUtil::ExpectEqual(*literal, *literal);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *literal));
 }
 
 TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) {
@@ -97,6 +97,15 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) {
   }
 }
 
+TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) {
+  auto expected = Literal::CreateR1<int32>({1, 2, 3});
+  auto actual = Literal::CreateR1<int32>({4, 5, 6});
+  ::testing::AssertionResult result =
+      LiteralTestUtil::Equal(*expected, *actual);
+  EXPECT_THAT(result.message(), ::testing::HasSubstr("expected: {1, 2, 3}"));
+  EXPECT_THAT(result.message(), ::testing::HasSubstr("actual:   {4, 5, 6}"));
+}
+
 TEST(LiteralTestUtilTest, NearComparatorR1) {
   auto a =
       Literal::CreateR1<float>({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
index 0a603f4..7778053 100644 (file)
@@ -108,7 +108,7 @@ class MultiOutputFusionTest : public HloTestBase {
     expect.PopulateWithValue<float>(size * 1.5f * 3.5f);
     auto actual = ExecuteAndTransfer(
         std::move(hlo_module), {Literal::CreateR0<float>(-9.0f).get(), &arg1});
-    LiteralTestUtil::ExpectNear(expect, *actual, error_spec_);
+    EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_));
   }
 
   void RunTest1D(bool manual_fusion, int size) {
@@ -168,7 +168,7 @@ class MultiOutputFusionTest : public HloTestBase {
 
     Literal expect = std::move(*Literal::CreateR1<float>({size * 1.5f * 3.5f}));
     auto actual = ExecuteAndTransfer(std::move(hlo_module), {&input0, &input1});
-    LiteralTestUtil::ExpectNear(expect, *actual, error_spec_);
+    EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_));
   }
 };
 
index 29a4f75..1a2de69 100644 (file)
@@ -273,11 +273,11 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) {
                                              &execution_options_));
   }
 
-  LiteralTestUtil::ExpectEqual(*result1, *result2);
-  LiteralTestUtil::ExpectEqual(*result1, *result3);
-  LiteralTestUtil::ExpectNotEqual(*result1, *result4);
-  LiteralTestUtil::ExpectNotEqual(*result4, *result5);
-  LiteralTestUtil::ExpectNotEqual(*result5, *result6);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*result1, *result2));
+  EXPECT_TRUE(LiteralTestUtil::Equal(*result1, *result3));
+  EXPECT_FALSE(LiteralTestUtil::Equal(*result1, *result4));
+  EXPECT_FALSE(LiteralTestUtil::Equal(*result4, *result5));
+  EXPECT_FALSE(LiteralTestUtil::Equal(*result5, *result6));
 }
 
 XLA_TEST_F(PrngTest, TenValuesN01) {
index d7462d5..a4580cd 100644 (file)
@@ -656,9 +656,9 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) {
   std::unique_ptr<Literal> expected =
       Literal::CreateR2FromArray2D<float>(expected_array);
   if (use_bfloat16()) {
-    expected = LiteralTestUtil::ConvertF32ToBF16(*expected);
+    expected = Literal::ConvertF32ToBF16(*expected);
   }
-  LiteralTestUtil::ExpectEqual(*expected, *actual);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual));
 }
 
 XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) {
@@ -731,7 +731,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) {
   builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1});
 
   std::unique_ptr<Literal> expected =
-      LiteralTestUtil::Reshape({2, 1}, {1, 0}, *input_literal);
+      Literal::ReshapeSlice({2, 1}, {1, 0}, *input_literal);
   ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
                            zero_error_spec_);
 }
@@ -753,7 +753,7 @@ XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) {
   builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2});
 
   std::unique_ptr<Literal> expected =
-      LiteralTestUtil::Reshape({4, 2}, {1, 0}, *input_literal);
+      Literal::ReshapeSlice({4, 2}, {1, 0}, *input_literal);
   ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
                            zero_error_spec_);
 }
@@ -817,7 +817,7 @@ XLA_TEST_P(ReshapeTest, NoopReshape) {
   // Since the reshape is a no-op, verify that it does not change the underlying
   // data.
   if (use_bfloat16()) {
-    auto expected = LiteralTestUtil::ConvertF32ToBF16(*input_literal);
+    auto expected = Literal::ConvertF32ToBF16(*input_literal);
     EXPECT_EQ(expected->data<bfloat16>(), output_literal->data<bfloat16>());
   } else {
     EXPECT_EQ(input_literal->data<float>(), output_literal->data<float>());
@@ -886,7 +886,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) {
                   /*new_sizes=*/new_bounds);
 
   std::unique_ptr<Literal> expected =
-      LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal)
+      Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
           ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
 
   // Specify the requested output shape explicitly to ensure that this reshape
@@ -915,7 +915,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) {
                   /*new_sizes=*/new_bounds);
 
   std::unique_ptr<Literal> expected =
-      LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal)
+      Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
           ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
 
   // Specify the requested output shape explicitly to ensure that this reshape
@@ -944,7 +944,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) {
                   /*new_sizes=*/new_bounds);
 
   std::unique_ptr<Literal> expected =
-      LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal)
+      Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
           ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
 
   // Specify the requested output shape explicitly to ensure that this reshape
@@ -974,7 +974,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) {
                   /*new_sizes=*/new_bounds);
 
   std::unique_ptr<Literal> expected =
-      LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal)
+      Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
           ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
 
   // Specify the requested output shape explicitly to ensure that this reshape
@@ -1003,7 +1003,7 @@ XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) {
                   /*new_sizes=*/new_bounds);
 
   std::unique_ptr<Literal> expected =
-      LiteralTestUtil::Reshape(new_bounds, {1, 0, 2, 3}, *input_literal)
+      Literal::ReshapeSlice(new_bounds, {1, 0, 2, 3}, *input_literal)
           ->Relayout(input_literal->shape().layout());
 
   // Specify the requested output shape explicitly to ensure that this reshape
index 8cbfcc6..7cfca78 100644 (file)
@@ -100,7 +100,7 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) {
   EXPECT_EQ(46.0f, actual->Get<float>({1, 1}));
 
   std::unique_ptr<Literal> round_tripped = RoundTripToServer(*actual);
-  LiteralTestUtil::ExpectEqual(*round_tripped, *actual);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*round_tripped, *actual));
 }
 
 TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) {
@@ -135,7 +135,7 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) {
   EXPECT_EQ(46.0f, actual->Get<float>({1, 1}));
 
   std::unique_ptr<Literal> round_tripped = RoundTripToServer(*actual);
-  LiteralTestUtil::ExpectEqual(*round_tripped, *actual);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*round_tripped, *actual));
 }
 
 }  // namespace
index 32db45f..f334a8c 100644 (file)
@@ -41,7 +41,7 @@ class RoundTripTransferTest : public ClientLibraryTestBase {
         client_->TransferToServer(original).ConsumeValueOrDie();
     std::unique_ptr<Literal> result =
         client_->Transfer(*data).ConsumeValueOrDie();
-    LiteralTestUtil::ExpectEqual(original, *result);
+    EXPECT_TRUE(LiteralTestUtil::Equal(original, *result));
   }
 };
 
index f35bc43..308d3fc 100644 (file)
@@ -390,7 +390,7 @@ XLA_TEST_F(ScalarComputationsTest, DivU32s) {
                                      &execution_options_)
                 .ConsumeValueOrDie();
         auto expected_literal = Literal::CreateR0<uint32>(dividend / divisor);
-        LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
+        EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal));
       }
     }
   }
@@ -431,7 +431,7 @@ XLA_TEST_F(ScalarComputationsTest, RemU32s) {
                                      &execution_options_)
                 .ConsumeValueOrDie();
         auto expected_literal = Literal::CreateR0<uint32>(dividend % divisor);
-        LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
+        EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal));
       }
     }
   }
index e2067bc..0063e7a 100644 (file)
@@ -175,7 +175,7 @@ XLA_TEST_F(TransferManagerTest, TransferTuple) {
                           transfer_manager_->TransferLiteralFromDevice(
                               stream_executor_, device_buffer));
 
-  LiteralTestUtil::ExpectEqual(*literal, *result);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
 }
 
 XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) {
@@ -189,7 +189,7 @@ XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) {
                           transfer_manager_->TransferLiteralFromDevice(
                               stream_executor_, device_buffer));
 
-  LiteralTestUtil::ExpectEqual(*literal, *result);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
 }
 
 XLA_TEST_F(TransferManagerTest, TransferNestedTuple) {
@@ -209,7 +209,7 @@ XLA_TEST_F(TransferManagerTest, TransferNestedTuple) {
                           transfer_manager_->TransferLiteralFromDevice(
                               stream_executor_, device_buffer));
 
-  LiteralTestUtil::ExpectEqual(*literal, *result);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
 }
 
 XLA_TEST_F(TransferManagerTest, TransferComplexValue) {
@@ -224,7 +224,7 @@ XLA_TEST_F(TransferManagerTest, TransferComplexValue) {
                           transfer_manager_->TransferLiteralFromDevice(
                               stream_executor_, device_buffer));
 
-  LiteralTestUtil::ExpectEqual(*literal, *result);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
 }
 
 XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) {
@@ -243,7 +243,7 @@ XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) {
                           transfer_manager_->TransferLiteralFromDevice(
                               stream_executor_, device_buffer));
 
-  LiteralTestUtil::ExpectEqual(*literal, *result);
+  EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
 }
 
 }  // namespace