xla::Literal::CreateR1<int32>({4, 143});
std::unique_ptr<xla::Literal> expected_literal =
xla::Literal::MakeTuple({expected0.get()});
- xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) {
xla::Literal::CreateR1<int32>({-7, -42});
std::unique_ptr<xla::Literal> expected_literal =
xla::Literal::MakeTuple({expected0.get()});
- xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
+ EXPECT_TRUE(
+ xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
{
xla::Literal::CreateR1<int32>({-7, -42});
std::unique_ptr<xla::Literal> expected =
xla::Literal::MakeTuple({expected0.get(), expected1.get()});
- xla::LiteralTestUtil::ExpectEqual(*expected, *actual_literal);
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected, *actual_literal));
}
}
{output_base.get(), output_grad1.get(), output_grad2.get()});
std::unique_ptr<xla::Literal> expected_literal =
xla::Literal::MakeTuple({output_read.get(), output_resource.get()});
- xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
// Tests compilation and execution of a graph that adds two tensors.
xla::Literal::CreateR1<int32>({4, 143});
std::unique_ptr<xla::Literal> expected_literal =
xla::Literal::MakeTuple({expected0.get(), expected1.get()});
- xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
// Tests a simple graph that reads and writes a variable, with a
xla::Literal::CreateR1<int32>({26, 66, 34, 401});
std::unique_ptr<xla::Literal> expected_literal =
xla::Literal::MakeTuple({expected0.get(), expected1.get()});
- xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
+ EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
} // namespace
)
cc_library(
+ name = "literal_comparison",
+ srcs = ["literal_comparison.cc"],
+ hdrs = ["literal_comparison.h"],
+ deps = [
+ ":literal_util",
+ ":util",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
name = "metric_table_report",
srcs = ["metric_table_report.cc"],
hdrs = ["metric_table_report.h"],
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/literal_comparison.h"
+
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/casts.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+using tensorflow::strings::StrCat;
+
+namespace xla {
+namespace literal_comparison {
+namespace {
+
+// Helper function for comparing a floating point type, FloatT, bitwise equal
+// between the left-hand-side and right-hand-side, by bit-casting to UnsignedT
+// -- on miscompare, a nice error message is given in the AssertionFailure.
+template <typename FloatT, typename UnsignedT>
+Status CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) {
+ auto ulhs = tensorflow::bit_cast<UnsignedT>(lhs);
+ auto urhs = tensorflow::bit_cast<UnsignedT>(rhs);
+ auto lhs_double = static_cast<double>(lhs);
+ auto rhs_double = static_cast<double>(rhs);
+ if (ulhs != urhs) {
+ return InvalidArgument(
+ "floating values are not bitwise-equal; and equality testing "
+ "was requested: %s=%g=%a vs %s=%g=%a",
+ StrCat(tensorflow::strings::Hex(ulhs)).c_str(), lhs_double, lhs_double,
+ StrCat(tensorflow::strings::Hex(urhs)).c_str(), rhs_double, rhs_double);
+ }
+ return Status::OK();
+}
+
+// Templated comparator that specializes for float equality comparison with the
+// bitwise helper above (this is the un-specialized fallback, to just use the
+// default gunit implementation).
+template <typename NativeT>
+Status CompareEqual(NativeT lhs, NativeT rhs) {
+ if (lhs == rhs) {
+ return Status::OK();
+ }
+ return InvalidArgument("Expected equality of these values:\n %s\n %s",
+ StrCat(lhs).c_str(), StrCat(rhs).c_str());
+}
+
+// Specializations for floating types that do bitwise comparisons when equality
+// comparison is requested.
+template <>
+Status CompareEqual<bfloat16>(bfloat16 lhs, bfloat16 rhs) {
+ return CompareFloatsBitwiseEqual<bfloat16, uint16>(lhs, rhs);
+}
+template <>
+Status CompareEqual<Eigen::half>(Eigen::half lhs, Eigen::half rhs) {
+ return CompareFloatsBitwiseEqual<Eigen::half, uint16>(lhs, rhs);
+}
+template <>
+Status CompareEqual<float>(float lhs, float rhs) {
+ return CompareFloatsBitwiseEqual<float, uint32>(lhs, rhs);
+}
+template <>
+Status CompareEqual<double>(double lhs, double rhs) {
+ return CompareFloatsBitwiseEqual<double, uint64>(lhs, rhs);
+}
+template <>
+Status CompareEqual<complex64>(complex64 lhs, complex64 rhs) {
+ auto res = CompareEqual<float>(lhs.real(), rhs.real());
+ if (!res.ok()) {
+ return res;
+ }
+ return CompareEqual<float>(lhs.imag(), rhs.imag());
+}
+
+// A recursive function which iterates through every index of expected and
+// actual literal and compares their values elementwise. Returns true if all
+// elements are equal.
+template <typename NativeT>
+Status Equal(LiteralSlice expected, LiteralSlice actual,
+ tensorflow::gtl::MutableArraySlice<int64> multi_index,
+ int64 dimension) {
+ if (dimension == expected.shape().dimensions_size()) {
+ NativeT expected_value = expected.Get<NativeT>(multi_index);
+ NativeT actual_value = actual.Get<NativeT>(multi_index);
+ return CompareEqual<NativeT>(expected_value, actual_value);
+ }
+
+ Status result;
+ for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) {
+ multi_index[dimension] = i;
+ result.Update(Equal<NativeT>(expected, actual, multi_index, dimension + 1));
+ }
+ return result;
+}
+
+} // namespace
+
+Status EqualShapes(const Shape& expected, const Shape& actual) {
+ if (ShapeUtil::IsTuple(expected) != ShapeUtil::IsTuple(actual)) {
+ return InvalidArgument("tupleness-mismatch! want: %s got %s",
+ ShapeUtil::HumanString(expected).c_str(),
+ ShapeUtil::HumanString(actual).c_str());
+ }
+ if (ShapeUtil::IsTuple(expected)) {
+ if (ShapeUtil::TupleElementCount(expected) !=
+ ShapeUtil::TupleElementCount(actual)) {
+ return InvalidArgument(
+ "want tuple element count: %lld got tuple element count: %lld",
+ ShapeUtil::TupleElementCount(expected),
+ ShapeUtil::TupleElementCount(actual));
+ }
+ for (int i = 0; i < expected.tuple_shapes_size(); ++i) {
+ Status result =
+ EqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i));
+ if (!result.ok()) {
+ return AppendStatus(result, StrCat("mismatch in tuple index", i));
+ }
+ }
+ } else {
+ if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) {
+ return InvalidArgument("want rank of %s got rank of %s",
+ ShapeUtil::HumanString(expected).c_str(),
+ ShapeUtil::HumanString(actual).c_str());
+ }
+ if (expected.element_type() != actual.element_type()) {
+ return InvalidArgument(
+ "mismatch in primitive type %s vs %s",
+ PrimitiveType_Name(expected.element_type()).c_str(),
+ PrimitiveType_Name(actual.element_type()).c_str());
+ }
+ if (expected.dimensions_size() != actual.dimensions_size()) {
+ return InvalidArgument("want dimensions_size %d got dimensions_size %d",
+ expected.dimensions_size(),
+ actual.dimensions_size());
+ }
+ for (int i = 0; i < expected.dimensions_size(); ++i) {
+ if (expected.dimensions(i) != actual.dimensions(i)) {
+ return InvalidArgument(
+ "mismatch in dimension #%d expected: %s actual: %s", i,
+ ShapeUtil::HumanString(expected).c_str(),
+ ShapeUtil::HumanString(actual).c_str());
+ }
+ }
+ }
+ return Status::OK();
+}
+
+Status Equal(const LiteralSlice& expected, const LiteralSlice& actual) {
+ VLOG(1) << "expected:";
+ XLA_VLOG_LINES(1, expected.ToString());
+ VLOG(1) << "actual:";
+ XLA_VLOG_LINES(1, actual.ToString());
+
+ TF_RETURN_IF_ERROR(EqualShapes(expected.shape(), actual.shape()));
+ std::vector<int64> multi_index(expected.shape().dimensions_size(), 0);
+ Status result;
+ switch (expected.shape().element_type()) {
+ case PRED:
+ result = Equal<bool>(expected, actual, &multi_index, 0);
+ break;
+ case U8:
+ result = Equal<uint8>(expected, actual, &multi_index, 0);
+ break;
+ case S32:
+ result = Equal<int32>(expected, actual, &multi_index, 0);
+ break;
+ case S64:
+ result = Equal<int64>(expected, actual, &multi_index, 0);
+ break;
+ case U32:
+ result = Equal<uint32>(expected, actual, &multi_index, 0);
+ break;
+ case U64:
+ result = Equal<uint64>(expected, actual, &multi_index, 0);
+ break;
+ case BF16:
+ result = Equal<bfloat16>(expected, actual, &multi_index, 0);
+ break;
+ case F16:
+ result = Equal<half>(expected, actual, &multi_index, 0);
+ break;
+ case F32:
+ result = Equal<float>(expected, actual, &multi_index, 0);
+ break;
+ case F64:
+ result = Equal<double>(expected, actual, &multi_index, 0);
+ break;
+ case C64:
+ result = Equal<complex64>(expected, actual, &multi_index, 0);
+ break;
+ case TUPLE: {
+ for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
+ result.Update(
+ Equal(LiteralSlice(expected, {i}), LiteralSlice(actual, {i})));
+ }
+ break;
+ }
+ default:
+ LOG(FATAL)
+ << "Unsupported primitive type in LiteralTestUtil::ExpectEqual: "
+ << PrimitiveType_Name(expected.shape().element_type());
+ }
+
+ if (result.ok()) {
+ return Status::OK();
+ }
+
+ return AppendStatus(result,
+ tensorflow::strings::Printf("expected: %s\nactual: %s",
+ expected.ToString().c_str(),
+ actual.ToString().c_str()));
+}
+
+} // namespace literal_comparison
+} // namespace xla
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Library for comparing literals without taking a dependency on testing
+// libraries.
+
+#ifndef TENSORFLOW_COMPILER_XLA_LITERAL_COMPARISON_H_
+#define TENSORFLOW_COMPILER_XLA_LITERAL_COMPARISON_H_
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace xla {
+namespace literal_comparison {
+
+// Returns ok if the given shapes have the same rank, dimension sizes, and
+// primitive types.
+Status EqualShapes(const Shape& expected, const Shape& actual);
+
+// Returns ok if the expected and actual literals are (bitwise) equal for all
+// elements in the literal. Also, asserts that the rank, dimensions sizes, and
+// primitive type are equal.
+Status Equal(const LiteralSlice& expected, const LiteralSlice& actual);
+
+} // namespace literal_comparison
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_LITERAL_COMPARISON_H_
}
}
+// Return a literal with all arrays of type FromNativeT converted to type
+// ToNativeT in the given literal.
+template <typename FromNativeT, typename ToNativeT>
+std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
+ // First construct shape of the result.
+ Shape result_shape(literal.shape());
+ ShapeUtil::ForEachMutableSubshape(
+ &result_shape, [](Shape* subshape, const ShapeIndex&) {
+ if (subshape->element_type() ==
+ primitive_util::NativeToPrimitiveType<FromNativeT>()) {
+ subshape->set_element_type(
+ primitive_util::NativeToPrimitiveType<ToNativeT>());
+ }
+ });
+ auto result = MakeUnique<Literal>(result_shape);
+
+ // Then copy over the data from 'literal' converting FromNativeT values to
+ // ToNativeT values as necessary.
+ ShapeUtil::ForEachSubshape(
+ literal.shape(),
+ [&](const Shape& subshape, const ShapeIndex& shape_index) {
+ if (ShapeUtil::IsArray(subshape)) {
+ if (subshape.element_type() ==
+ primitive_util::NativeToPrimitiveType<FromNativeT>()) {
+ auto src = literal.data<FromNativeT>(shape_index);
+ auto dest = result->data<ToNativeT>(shape_index);
+ for (int64 i = 0; i < src.size(); ++i) {
+ dest[i] = static_cast<ToNativeT>(src[i]);
+ }
+ } else {
+ TF_CHECK_OK(result->CopyFrom(literal,
+ /*dest_shape_index=*/shape_index,
+ /*src_shape_index=*/shape_index));
+ }
+ }
+ });
+ return result;
+}
+
} // namespace
LiteralBase::~LiteralBase() {}
return CreateFromShape(ShapeUtil::MakeShape(primitive_type, dimensions));
}
+/* static */ std::unique_ptr<Literal> Literal::ConvertBF16ToF32(
+ const LiteralSlice& bf16_literal) {
+ return ConvertType<bfloat16, float>(bf16_literal);
+}
+
+/* static */ std::unique_ptr<Literal> Literal::ConvertF32ToBF16(
+ const LiteralSlice& f32_literal) {
+ return ConvertType<float, bfloat16>(f32_literal);
+}
+
template <typename NativeT>
Status Literal::CopySliceFromInternal(
const LiteralBase& src_literal, tensorflow::gtl::ArraySlice<int64> src_base,
return std::move(output);
}
+/* static */ std::unique_ptr<Literal> Literal::ReshapeSlice(
+ tensorflow::gtl::ArraySlice<int64> new_dimensions,
+ tensorflow::gtl::ArraySlice<int64> minor_to_major,
+ const LiteralSlice& literal) {
+ int64 new_num_elements = 1;
+ for (int64 i = 0; i < new_dimensions.size(); ++i) {
+ new_num_elements *= new_dimensions[i];
+ }
+ CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements);
+ CHECK_EQ(new_dimensions.size(), minor_to_major.size());
+
+ auto new_literal = MakeUnique<Literal>(
+ ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions));
+
+ // Create a new shape with the given minor-to-major layout. This shape is used
+ // solely for converting linear address to multi-dimensional addresses when
+ // writing elements to the new literal.
+ Shape shape_with_layout = new_literal->shape();
+ *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
+
+ // Copy data into new literal, element-by-element.
+ for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) {
+ std::vector<int64> from_multi_index =
+ IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i);
+ std::vector<int64> to_multi_index =
+ IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i);
+ switch (literal.shape().element_type()) {
+ case PRED:
+ new_literal->Set<bool>(to_multi_index,
+ literal.Get<bool>(from_multi_index));
+ break;
+ case U8:
+ new_literal->Set<uint8>(to_multi_index,
+ literal.Get<uint8>(from_multi_index));
+ break;
+ case U32:
+ new_literal->Set<uint32>(to_multi_index,
+ literal.Get<uint32>(from_multi_index));
+ break;
+ case S32:
+ new_literal->Set<int32>(to_multi_index,
+ literal.Get<int32>(from_multi_index));
+ break;
+ case U64:
+ new_literal->Set<uint64>(to_multi_index,
+ literal.Get<uint64>(from_multi_index));
+ break;
+ case S64:
+ new_literal->Set<int64>(to_multi_index,
+ literal.Get<int64>(from_multi_index));
+ break;
+ case F32:
+ new_literal->Set<float>(to_multi_index,
+ literal.Get<float>(from_multi_index));
+ break;
+ case F64:
+ new_literal->Set<double>(to_multi_index,
+ literal.Get<double>(from_multi_index));
+ break;
+ case C64:
+ new_literal->Set<complex64>(to_multi_index,
+ literal.Get<complex64>(from_multi_index));
+ break;
+ default:
+ LOG(FATAL) << "Unhandled primitive element type: "
+ << PrimitiveType_Name(literal.shape().element_type());
+ }
+ }
+
+ return new_literal;
+}
+
std::unique_ptr<Literal> LiteralBase::Transpose(
tensorflow::gtl::ArraySlice<int64> permutation) const {
CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose";
return std::move(literal);
}
+/* static */ string Literal::MultiIndexAsString(
+ tensorflow::gtl::ArraySlice<int64> multi_index) {
+ return StrCat("{", tensorflow::str_util::Join(multi_index, ","), "}");
+}
+
const void* LiteralBase::untyped_data(const ShapeIndex& shape_index) const {
return piece(shape_index).untyped_data();
}
PrimitiveType primitive_type,
tensorflow::gtl::ArraySlice<int64> dimensions);
+ // If the given literal's data type is bfloat16, converts it to a float
+ // literal; otherwise, returns a copy of it. If the literal is a tuple,
+ // recursively converts its elements.
+ static std::unique_ptr<Literal> ConvertBF16ToF32(
+ const LiteralSlice& bf16_literal);
+
+ // If the given literal's data type is float, converts it to a bfloat16
+ // literal; otherwise, returns a copy of it. If the literal is a tuple,
+ // recursively converts its elements.
+ static std::unique_ptr<Literal> ConvertF32ToBF16(
+ const LiteralSlice& f32_literal);
+
+ // Creates a literal with a new shape with the given new dimensions using the
+ // data in the given input literal. For reshaping purposes the (flat) data
+ // buffer of the input literal is assumed to have the given minor_to_major
+ // layout order.
+ static std::unique_ptr<Literal> ReshapeSlice(
+ tensorflow::gtl::ArraySlice<int64> new_dimensions,
+ tensorflow::gtl::ArraySlice<int64> minor_to_major,
+ const LiteralSlice& literal);
+
+ // Creates a literal with the supplied shape, and uses the provided value
+ // generator to populate the literal's values.
+ // Returns the new literal object, or an error Status if failed.
+ template <
+ PrimitiveType type,
+ typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
+ static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
+ const Shape& shape,
+ const std::function<T(tensorflow::gtl::ArraySlice<int64>)>& generator);
+
+ // Creates a literal with the supplied shape, and initializes the literal
+ // values using a normal distribution with given mean and stddev standard
+ // deviation, and using the engine as entropy generator.
+ // Returns the new literal object, or an error Status if failed.
+ template <
+ PrimitiveType type, typename E,
+ typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
+ static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
+ const Shape& shape, E* engine, T mean, T stddev);
+
+ // Creates a literal with the supplied shape, and initializes the literal
+ // values using a normal distribution with given mean and stddev standard
+ // deviation.
+ // Returns the new literal object, or an error Status if failed.
+ template <
+ PrimitiveType type,
+ typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
+ static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
+ const Shape& shape, T mean, T stddev);
+
//
// End of factory methods.
+ // Returns a multi-dimensional index as a string. For example: '{7, 8}' will
+ // be returned for a 2-dimensional index with dimension 0 index equal to 7,
+ // dimension 1 equal to 8.
+ static string MultiIndexAsString(
+ tensorflow::gtl::ArraySlice<int64> multi_index);
+
protected:
// Recursively sets the subshapes and buffers of all subpieces rooted at
// 'piece'. If 'allocate_array' is true, memory is allocated for the arrays in
return literal;
}
+template <PrimitiveType type, typename T>
+/* static */ StatusOr<std::unique_ptr<Literal>> Literal::CreateRandomLiteral(
+ const Shape& shape,
+ const std::function<T(tensorflow::gtl::ArraySlice<int64>)>& generator) {
+ using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
+ TF_RET_CHECK(shape.element_type() == type);
+ std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape);
+ TF_RETURN_IF_ERROR(literal.get()->Populate<NativeT>(
+ [&](tensorflow::gtl::ArraySlice<int64> indexes) {
+ return generator(indexes);
+ }));
+ return std::move(literal);
+}
+
+template <PrimitiveType type, typename E, typename T>
+/* static */ StatusOr<std::unique_ptr<Literal>> Literal::CreateRandomLiteral(
+ const Shape& shape, E* engine, T mean, T stddev) {
+ using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
+ std::normal_distribution<NativeT> generator(mean, stddev);
+ return CreateRandomLiteral<type, NativeT>(
+ shape, [&](tensorflow::gtl::ArraySlice<int64> /*indexes*/) {
+ return generator(*engine);
+ });
+}
+
+template <PrimitiveType type, typename T>
+/* static */ StatusOr<std::unique_ptr<Literal>> Literal::CreateRandomLiteral(
+ const Shape& shape, T mean, T stddev) {
+ std::minstd_rand0 engine;
+ return CreateRandomLiteral<type>(shape, &engine, mean, stddev);
+}
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build());
TF_ASSERT_OK_AND_ASSIGN(auto result_literal, client_->ExecuteAndTransfer(
computation, {}, nullptr));
- LiteralTestUtil::ExpectNear(*expected_literal, *result_literal,
- ErrorSpec(0.0001));
+ EXPECT_TRUE(LiteralTestUtil::Near(*expected_literal, *result_literal,
+ ErrorSpec(0.0001)));
}
} // namespace
EXPECT_TRUE(OutputsBF16(dot->operand(1)));
EXPECT_EQ(dot->operand(0)->opcode(), HloOpcode::kConstant);
EXPECT_EQ(dot->operand(1)->opcode(), HloOpcode::kConstant);
- LiteralTestUtil::ExpectEqual(
+ EXPECT_TRUE(LiteralTestUtil::Equal(
dot->operand(0)->literal(),
- *LiteralTestUtil::ConvertF32ToBF16(*Literal::CreateFromArray(array_a)));
- LiteralTestUtil::ExpectEqual(
+ *Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_a))));
+ EXPECT_TRUE(LiteralTestUtil::Equal(
dot->operand(1)->literal(),
- *LiteralTestUtil::ConvertF32ToBF16(*Literal::CreateFromArray(array_b)));
+ *Literal::ConvertF32ToBF16(*Literal::CreateFromArray(array_b))));
}
// Tests that BF16 can be propagated through nested tuples.
const int64 slice_limits[] = {10, 8, 6, 5, 9};
const int64 slice_strides[] = {1, 1, 1, 1, 1};
TF_ASSERT_OK_AND_ASSIGN(auto literal,
- LiteralTestUtil::CreateRandomLiteral<F32>(
+ Literal::CreateRandomLiteral<F32>(
ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
HloInstruction* literal_instruction = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
HloComputation::Builder builder(TestName());
const int64 dimensions[] = {11, 8, 7, 5, 9};
TF_ASSERT_OK_AND_ASSIGN(auto literal,
- LiteralTestUtil::CreateRandomLiteral<F32>(
+ Literal::CreateRandomLiteral<F32>(
ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
auto literal_clone = literal->Literal::CloneToUnique();
HloInstruction* literal_instruction = builder.AddInstruction(
auto result = ExecuteAndTransfer(std::move(module), {});
auto expected = Literal::CreateR0<float>(84.0);
- LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4));
+ EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
}
TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
auto result = ExecuteAndTransfer(std::move(module), {});
auto expected = Literal::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
- LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4));
+ EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
}
TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
auto result = ExecuteAndTransfer(std::move(module), {});
auto expected = Literal::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
- LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4));
+ EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(1e-4)));
}
TEST_F(HloCseTest, ConstantsSameValueDifferentType) {
auto element_type = expected->shape().element_type();
if (element_type == F32 || element_type == F64) {
ErrorSpec error(aabs);
- LiteralTestUtil::ExpectNear(*expected, *result, error);
+ EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, error));
} else {
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
}
std::unique_ptr<Literal> result = Evaluate();
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
bool use_bfloat16_;
auto expected = Literal::CreateR2<float>({{0, 4}, {2, 4}});
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) {
auto expected = Literal::CreateR2<float>({{0, 0}, {1, 1}});
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
// Verifies that HloEvaluator evaluates a HLO instruction that performs select
auto expected = Literal::CreateR2<float>({{2, 5}, {0, 4}});
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
// Verifies that HloEvaluator evaluates a HLO instruction that performs
auto expected = Literal::CreateR2<int64>({{4, -16}, {-196, 12}});
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
// Verifies Reshape operation is correctly evaluated.
HloComputation::Builder b(TestName());
const int64 dimensions[] = {11, 8, 7, 5, 9};
TF_ASSERT_OK_AND_ASSIGN(auto literal,
- LiteralTestUtil::CreateRandomLiteral<F32>(
+ Literal::CreateRandomLiteral<F32>(
ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
auto literal_clone = literal->CloneToUnique();
HloInstruction* literal_instruction =
std::unique_ptr<Literal> result = Evaluate({});
- LiteralTestUtil::ExpectEqual(*result, *output_literal);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal));
}
TEST_P(HloEvaluatorTest, DoesBroadcastScalar) {
std::unique_ptr<Literal> result = Evaluate({});
- LiteralTestUtil::ExpectEqual(*result, *output_literal);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*result, *output_literal));
}
TEST_P(HloEvaluatorTest, DoesConcatenateSimple) {
auto expected =
Literal::CreateR2<int64>({{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}});
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) {
std::unique_ptr<Literal> result = Evaluate();
auto expected = Literal::CreateR1<int64>({100, 200});
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, ConvertWithSameLayout) {
std::unique_ptr<Literal> result = Evaluate();
- LiteralTestUtil::ExpectEqual(*result, *expected);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
}
TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) {
std::unique_ptr<Literal> result = Evaluate();
- LiteralTestUtil::ExpectEqual(*result, *expected);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
}
PaddingConfig CreatePaddingConfig(
auto expected = Literal::CreateR2<int32>(
{{10, 10}, {10, 10}, {10, 10}, {10, 10}, {10, 10}});
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) {
auto expected = Literal::CreateR4FromArray4D<float>(*expected_array);
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, NegativePadding2D) {
(*expected_array)(0, 4) = 2.718f;
auto expected = Literal::CreateR2FromArray2D<float>(*expected_array);
- LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(0x1.0P-5));
+ EXPECT_TRUE(LiteralTestUtil::Near(*expected, *result, ErrorSpec(0x1.0P-5)));
}
TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) {
auto expected_array = MakeUnique<Array2D<float>>(0, 9);
auto expected = Literal::CreateR2FromArray2D<float>(*expected_array);
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, DotRank2AndRank1) {
// clang-format on
auto expected = Literal::CreateR2FromArray2D<float>(expected_array);
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, DotRank1AndRank2) {
auto expected = Literal::CreateR1<float>({22.f, 28.f});
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, DotRank2AndRank2) {
});
auto expected = Literal::CreateR2FromArray2D<float>(expected_array);
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, SimpleConv1D) {
Array3D<float> expected_array = {{{11.f, 18.f, 9.f}}};
auto expected = Literal::CreateR3FromArray3D<float>(expected_array);
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) {
// clang-format on
auto expected = Literal::CreateR4FromArray4D<float>(expected_array);
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) {
auto expected = Literal::CreateR4FromArray4D<float>(
use_bfloat16_ ? expected_array_bf16 : expected_array);
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) {
auto expected = Literal::CreateR4FromArray4D<float>(
use_bfloat16_ ? expected_array_bf16 : expected_array);
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) {
}));
auto expected = Literal::CreateR4FromArray4D<float>(expected_array);
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) {
}));
auto expected = Literal::CreateR4FromArray4D<float>(expected_array);
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest,
}));
auto expected = Literal::CreateR4FromArray4D<float>(expected_array);
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
class HloEvaluatorPreciseReduceTest : public HloVerifiedTestBase {};
auto expected = Literal::CreateR1<float>({6, 18});
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, ReduceWindowMax) {
std::unique_ptr<Literal> result = Evaluate();
auto expected = Literal::CreateR2<float>({{6, 7}});
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, ReduceWindowAdd) {
std::unique_ptr<Literal> result = Evaluate();
auto expected = Literal::CreateR2<float>({{1, 3, 5}, {5, 11, 13}});
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) {
std::vector<int64> output_dims = {4, 3, 3, 3, 4, 4};
std::unique_ptr<Literal> result_literal =
Literal::CreateFullWithDescendingLayout<float>(output_dims, 8.0f);
- LiteralTestUtil::ExpectEqual(*result_literal, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*result_literal, *result));
}
TEST_P(HloEvaluatorTest, StridedSlice) {
{19},
});
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, DynamicSlice) {
{6, 7, 8},
});
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
// Verifies that the HloEvaluator's implementation goes along with existing
{6, 7, 8},
});
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, DynamicSliceUpdate) {
{5, -6, -7},
});
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, SetAndGetTuples) {
{5, 6, 7},
});
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) {
result_inner_literal.get(),
});
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, Reverse) {
});
// clang-format on
- LiteralTestUtil::ExpectEqual(*expected, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *result));
}
TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) {
add, {{param0, Literal::CreateR1<float>({1, 2, 3, 4}).get()},
{square, Literal::CreateR1<float>({10, 20, 30, 40}).get()}});
TF_ASSERT_OK(result.status());
- LiteralTestUtil::ExpectEqual(*Literal::CreateR1<float>({11, 22, 33, 44}),
- *result.ValueOrDie());
+ EXPECT_TRUE(LiteralTestUtil::Equal(
+ *Literal::CreateR1<float>({11, 22, 33, 44}), *result.ValueOrDie()));
}
// Check that EvaluateWithSubstitutions works if one of the operands to the op
auto result = evaluator.EvaluateWithSubstitutions(
add, {{square, Literal::CreateR1<float>({10, 20, 30, 40}).get()}});
TF_ASSERT_OK(result.status());
- LiteralTestUtil::ExpectEqual(*Literal::CreateR1<float>({11, 22, 33, 44}),
- *result.ValueOrDie());
+ EXPECT_TRUE(LiteralTestUtil::Equal(
+ *Literal::CreateR1<float>({11, 22, 33, 44}), *result.ValueOrDie()));
}
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV1) {
std::unique_ptr<Literal> operand =
Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2});
- LiteralTestUtil::ExpectEqual(
- *Literal::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}}),
- *Evaluate({operand.get(), gather_indices.get()}));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}}),
+ *Evaluate({operand.get(), gather_indices.get()})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherV2) {
std::unique_ptr<Literal> operand =
Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2});
- LiteralTestUtil::ExpectEqual(
+ EXPECT_TRUE(LiteralTestUtil::Equal(
*Literal::CreateR2<int32>({{1, 3}, {4, 6}, {7, 9}}),
- *Evaluate({operand.get(), gather_indices.get()}));
+ *Evaluate({operand.get(), gather_indices.get()})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherMultipleBatchDims) {
Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> gather_indices =
Literal::CreateR2<int32>({{0, 2}, {2, 1}});
- LiteralTestUtil::ExpectEqual(
+ EXPECT_TRUE(LiteralTestUtil::Equal(
*Literal::CreateR3<int32>(
{{{1, 3}, {4, 6}, {7, 9}}, {{3, 2}, {6, 5}, {9, 8}}}),
- *Evaluate({operand.get(), gather_indices.get()}));
+ *Evaluate({operand.get(), gather_indices.get()})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_TensorFlowGatherNd) {
{{-7, 7}, {-8, 8}, {-9, 9}}});
std::unique_ptr<Literal> gather_indices =
Literal::CreateR2<int32>({{0, 0}, {1, 0}});
- LiteralTestUtil::ExpectEqual(
- *Literal::CreateR2<int32>({{-1, 1}, {-4, 4}}),
- *Evaluate({operand.get(), gather_indices.get()}));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{-1, 1}, {-4, 4}}),
+ *Evaluate({operand.get(), gather_indices.get()})));
}
TEST_P(HloEvaluatorTest,
{{-7, 7}, {-8, 8}, {-9, 9}}});
std::unique_ptr<Literal> gather_indices =
Literal::CreateR2<int32>({{0, 0}, {1, 0}});
- LiteralTestUtil::ExpectEqual(
- *Literal::CreateR2<int32>({{-2, 2}, {-1, 1}}),
- *Evaluate({operand.get(), gather_indices.get()}));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{-2, 2}, {-1, 1}}),
+ *Evaluate({operand.get(), gather_indices.get()})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_DynamicSlice) {
std::unique_ptr<Literal> operand =
Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({1, 1});
- LiteralTestUtil::ExpectEqual(
- *Literal::CreateR2<int32>({{5}}),
- *Evaluate({operand.get(), gather_indices.get()}));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{5}}),
+ *Evaluate({operand.get(), gather_indices.get()})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_BatchDynamicSlice) {
Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}});
std::unique_ptr<Literal> gather_indices =
Literal::CreateR2<int32>({{2, 1}, {1, 1}});
- LiteralTestUtil::ExpectEqual(
- *Literal::CreateR3<int32>({{{8}}, {{5}}}),
- *Evaluate({operand.get(), gather_indices.get()}));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(*Literal::CreateR3<int32>({{{8}}, {{5}}}),
+ *Evaluate({operand.get(), gather_indices.get()})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_ZeroDimBounds) {
ParseAndVerifyModule(hlo_text);
std::unique_ptr<Literal> operand = Literal::CreateR2<int32>({{}, {}, {}});
std::unique_ptr<Literal> gather_indices = Literal::CreateR1<int32>({0, 2});
- LiteralTestUtil::ExpectEqual(
- *Literal::CreateR2<int32>({{}, {}}),
- *Evaluate({operand.get(), gather_indices.get()}));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{}, {}}),
+ *Evaluate({operand.get(), gather_indices.get()})));
}
TEST_P(HloEvaluatorTest, EvaluateGather_NoOutputWindowDims) {
std::unique_ptr<Literal> operand = Literal::CreateR1<int32>({0, 1, 2});
std::unique_ptr<Literal> gather_indices =
Literal::CreateR3<int32>({{{0}, {1}}, {{2}, {1}}});
- LiteralTestUtil::ExpectEqual(
- *Literal::CreateR2<int32>({{0, 1}, {2, 1}}),
- *Evaluate({operand.get(), gather_indices.get()}));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{0, 1}, {2, 1}}),
+ *Evaluate({operand.get(), gather_indices.get()})));
}
// Verifies that HloEvaluator evaluates a HLO instruction that performs
// Verify execution on CPU.
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
auto expected = Literal::CreateR1<float>({4, 3, 3, 4});
- LiteralTestUtil::ExpectEqual(*result, *expected);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
}
// Test that `constant` function is changed to `broadcast`.
// Verify execution on CPU.
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
auto expected = Literal::CreateR2<float>({{2, 2, 2, 2}, {2, 2, 2, 2}});
- LiteralTestUtil::ExpectEqual(*result, *expected);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
}
TEST_F(InlinerTest, MapSubtractOppositeOrder) {
// Verify execution on CPU.
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
auto expected = Literal::CreateR1<float>({3, 1, -1, -3});
- LiteralTestUtil::ExpectEqual(*result, *expected);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*result, *expected));
}
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array3d",
"//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:literal_comparison",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- LiteralTestUtil::ExpectNear(*Literal::CreateR0<float>(42.0), *result,
- error_spec_);
+ EXPECT_TRUE(LiteralTestUtil::Near(*Literal::CreateR0<float>(42.0), *result,
+ error_spec_));
}
XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) {
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- LiteralTestUtil::ExpectNear(
+ EXPECT_TRUE(LiteralTestUtil::Near(
*Literal::CreateR2<float>({{42.0, 42.0}, {42.0, 42.0}}), *result,
- error_spec_);
+ error_spec_));
}
XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) {
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- LiteralTestUtil::ExpectNear(
+ EXPECT_TRUE(LiteralTestUtil::Near(
*Literal::CreateR2<float>({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}),
- LiteralSlice(*result, {0}), error_spec_);
+ LiteralSlice(*result, {0}), error_spec_));
- LiteralTestUtil::ExpectNear(
+ EXPECT_TRUE(LiteralTestUtil::Near(
*Literal::CreateR2<float>({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}),
- LiteralSlice(*result, {1}), error_spec_);
+ LiteralSlice(*result, {1}), error_spec_));
}
XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- LiteralTestUtil::ExpectNear(
- *Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}), *result,
- error_spec_);
+ EXPECT_TRUE(
+ LiteralTestUtil::Near(*Literal::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}),
+ *result, error_spec_));
}
XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) {
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- LiteralTestUtil::ExpectNear(
- *Literal::CreateR2<float>({{1.0, 3.0}, {2.0, 4.0}}), *result,
- error_spec_);
+ EXPECT_TRUE(
+ LiteralTestUtil::Near(*Literal::CreateR2<float>({{1.0, 3.0}, {2.0, 4.0}}),
+ *result, error_spec_));
}
XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) {
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- LiteralTestUtil::ExpectNear(
+ EXPECT_TRUE(LiteralTestUtil::Near(
*Literal::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
{{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}),
- *result, error_spec_);
+ *result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) {
Array2D<float> pz({{1, 2}, {1, 2}});
expected.FillWithPZ(pz);
- LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D<float>(expected),
- *result, error_spec_);
+ EXPECT_TRUE(LiteralTestUtil::Near(
+ *Literal::CreateR4FromArray4D<float>(expected), *result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) {
}
expected.FillWithYX(yx);
- LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D<float>(expected),
- *result, error_spec_);
+ EXPECT_TRUE(LiteralTestUtil::Near(
+ *Literal::CreateR4FromArray4D<float>(expected), *result, error_spec_));
}
XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) {
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(r4_array), *result,
- error_spec_);
+ EXPECT_TRUE(LiteralTestUtil::Near(*Literal::CreateR4FromArray4D(r4_array),
+ *result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) {
Array4D<float> expected(64, 64, 3, 3);
expected.Fill(1.0f);
- LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D<float>(expected),
- *result, error_spec_);
+ EXPECT_TRUE(LiteralTestUtil::Near(
+ *Literal::CreateR4FromArray4D<float>(expected), *result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) {
Array4D<float> expected(3, 3, 2, 2);
expected.FillWithYX(to_broadcast);
- LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D<float>(expected),
- *result, error_spec_);
+ EXPECT_TRUE(LiteralTestUtil::Near(
+ *Literal::CreateR4FromArray4D<float>(expected), *result, error_spec_));
}
TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
hlo_module->AddEntryComputation(builder.Build());
auto result = ExecuteAndTransfer(std::move(hlo_module), {});
- LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D<float>(expected),
- *result, error_spec_);
+ EXPECT_TRUE(LiteralTestUtil::Near(
+ *Literal::CreateR4FromArray4D<float>(expected), *result, error_spec_));
}
} // namespace
std::unique_ptr<Literal> converted_expected;
Shape layout_shape;
if (use_bfloat16_) {
- converted_expected = LiteralTestUtil::ConvertF32ToBF16(expected);
+ converted_expected = Literal::ConvertF32ToBF16(expected);
expected_ptr = converted_expected.get();
if (shape_with_layout != nullptr) {
layout_shape = *shape_with_layout;
}
}
auto expect_equal = [&](const Literal& actual, const string& error_message) {
- LiteralTestUtil::ExpectEqual(*expected_ptr, actual, error_message);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, actual)) << error_message;
};
if (execution_options_.debug_options().xla_test_all_output_layouts()) {
return ComputeAndCompareLiteralWithAllOutputLayouts(
}
TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments,
shape_with_layout));
- LiteralTestUtil::ExpectEqual(*expected_ptr, *actual);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected_ptr, *actual));
return tensorflow::Status::OK();
}
std::unique_ptr<Literal> converted_expected;
Shape layout_shape;
if (use_bfloat16_) {
- converted_expected = LiteralTestUtil::ConvertF32ToBF16(expected);
+ converted_expected = Literal::ConvertF32ToBF16(expected);
expected_ptr = converted_expected.get();
if (shape_with_layout != nullptr) {
layout_shape = *shape_with_layout;
}
}
auto expect_near = [&](const Literal& actual, const string& error_message) {
- LiteralTestUtil::ExpectNear(*expected_ptr, actual, error, error_message);
+ EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, actual, error))
+ << error_message;
};
if (execution_options_.debug_options().xla_test_all_output_layouts()) {
return ComputeAndCompareLiteralWithAllOutputLayouts(
}
TF_ASSIGN_OR_RETURN(auto actual, ExecuteAndTransfer(computation, arguments,
shape_with_layout));
- LiteralTestUtil::ExpectNear(*expected_ptr, *actual, error);
+ EXPECT_TRUE(LiteralTestUtil::Near(*expected_ptr, *actual, error));
return tensorflow::Status::OK();
}
return;
}
auto actual = actual_status.ConsumeValueOrDie();
- LiteralTestUtil::ExpectEqual(expected, *actual);
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, *actual));
}
void ClientLibraryTestBase::ComputeAndCompareTuple(
return;
}
auto actual = actual_status.ConsumeValueOrDie();
- LiteralTestUtil::ExpectNear(expected, *actual, error);
+ EXPECT_TRUE(LiteralTestUtil::Near(expected, *actual, error));
}
void ClientLibraryTestBase::ComputeAndCompare(
}
std::unique_ptr<Literal> reference, result;
std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
- LiteralTestUtil::ExpectEqual(*reference, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*reference, *result));
}
void ClientLibraryTestBase::ComputeAndCompare(
}
std::unique_ptr<Literal> reference, result;
std::tie(reference, result) = status_or_data.ConsumeValueOrDie();
- LiteralTestUtil::ExpectNear(*reference, *result, error);
+ EXPECT_TRUE(LiteralTestUtil::Near(*reference, *result, error));
}
StatusOr<std::pair<std::unique_ptr<Literal>, std::unique_ptr<Literal>>>
XlaOp ClientLibraryTestBase::CreateConstantFromLiteral(const Literal& literal,
XlaBuilder* builder) {
return builder->ConstantLiteral(
- use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal);
+ use_bfloat16_ ? *Literal::ConvertF32ToBF16(literal) : literal);
}
std::unique_ptr<GlobalData>
const Literal* param_literal = &literal;
std::unique_ptr<Literal> converted_literal;
if (use_bfloat16_) {
- converted_literal = LiteralTestUtil::ConvertF32ToBF16(literal);
+ converted_literal = Literal::ConvertF32ToBF16(literal);
param_literal = converted_literal.get();
}
std::unique_ptr<GlobalData> data =
XlaBuilder* builder, XlaOp* data_handle) {
std::unique_ptr<Literal> literal = Literal::CreateR0(value);
if (use_bfloat16_ && literal->shape().element_type() == F32) {
- literal = LiteralTestUtil::ConvertF32ToBF16(*literal);
+ literal = Literal::ConvertF32ToBF16(*literal);
}
std::unique_ptr<GlobalData> data =
client_->TransferToServer(*literal).ConsumeValueOrDie();
const string& name, XlaBuilder* builder, XlaOp* data_handle) {
std::unique_ptr<Literal> literal = Literal::CreateR1(values);
if (use_bfloat16_ && literal->shape().element_type() == F32) {
- literal = LiteralTestUtil::ConvertF32ToBF16(*literal);
+ literal = Literal::ConvertF32ToBF16(*literal);
}
std::unique_ptr<GlobalData> data =
client_->TransferToServer(*literal).ConsumeValueOrDie();
const string& name, XlaBuilder* builder, XlaOp* data_handle) {
std::unique_ptr<Literal> literal = Literal::CreateR2FromArray2D(array_2d);
if (use_bfloat16_ && literal->shape().element_type() == F32) {
- literal = LiteralTestUtil::ConvertF32ToBF16(*literal);
+ literal = Literal::ConvertF32ToBF16(*literal);
}
std::unique_ptr<GlobalData> data =
client_->TransferToServer(*literal).ConsumeValueOrDie();
const string& name, XlaBuilder* builder, XlaOp* data_handle) {
std::unique_ptr<Literal> literal = Literal::CreateR3FromArray3D(array_3d);
if (use_bfloat16_ && literal->shape().element_type() == F32) {
- literal = LiteralTestUtil::ConvertF32ToBF16(*literal);
+ literal = Literal::ConvertF32ToBF16(*literal);
}
std::unique_ptr<GlobalData> data =
client_->TransferToServer(*literal).ConsumeValueOrDie();
TF_ASSERT_OK_AND_ASSIGN(
auto computed, client_->Transfer(*data, &expected_literal->shape()));
- LiteralTestUtil::AssertEqualShapesAndLayouts(expected_literal->shape(),
- computed->shape());
- LiteralTestUtil::ExpectEqual(*expected_literal, *computed);
+ ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts(
+ expected_literal->shape(), computed->shape()));
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
}
}
}
auto result_literal,
client_->Transfer(*results[0], &expected_result->shape()));
- LiteralTestUtil::ExpectEqual(*expected_result, *result_literal);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected_result, *result_literal));
}
} // namespace
/*execution_options=*/&execution_options_,
&execution_profile)
.ConsumeValueOrDie();
- LiteralTestUtil::ExpectNear(*Literal::CreateR0<float>(expected_result),
- *result, error_spec_);
+ EXPECT_TRUE(LiteralTestUtil::Near(
+ *Literal::CreateR0<float>(expected_result), *result, error_spec_));
EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
}
.ConsumeValueOrDie();
std::unique_ptr<Literal> result =
client_->Transfer(*data_handle).ConsumeValueOrDie();
- LiteralTestUtil::ExpectNear(*Literal::CreateR2<float>(expected_result),
- *result, error_spec_);
+ EXPECT_TRUE(LiteralTestUtil::Near(
+ *Literal::CreateR2<float>(expected_result), *result, error_spec_));
EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
}
ComputeConstantLiteral(client, computation, &b));
std::unique_ptr<Literal> expected_literal =
Literal::CreateR1<int32>({4, 6});
- LiteralTestUtil::ExpectEqual(*expected_literal, *computed);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
}
}
TF_ASSERT_OK_AND_ASSIGN(auto computed,
ComputeConstantLiteral(client, computation, &b));
std::unique_ptr<Literal> expected_literal = Literal::CreateR0<int32>(5);
- LiteralTestUtil::ExpectEqual(*expected_literal, *computed);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
}
}
std::unique_ptr<Literal> expected_literal =
Literal::CreateR2WithLayout<int32>({{11, 22}, {33, 44}},
LayoutUtil::MakeLayout(layout));
- LiteralTestUtil::AssertEqualShapesAndLayouts(expected_literal->shape(),
- computed->shape());
- LiteralTestUtil::ExpectEqual(*expected_literal, *computed);
+ ASSERT_TRUE(LiteralTestUtil::EqualShapesAndLayouts(
+ expected_literal->shape(), computed->shape()));
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *computed));
}
}
}
module->AddEntryComputation(std::move(computation));
std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
- LiteralTestUtil::ExpectEqual(literal, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(literal, *result));
}
void TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3);
auto actual = ExecuteAndTransfer(&builder, {input_data.get()}, &out_shape)
.ConsumeValueOrDie();
- LiteralTestUtil::ExpectEqual(*empty, *actual);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*empty, *actual));
}
} // namespace
auto expected = Literal::CreateR2FromArray2D(answer_data);
auto actual = ExecuteAndTransfer(std::move(hlo_module), {});
if (primitive_util::IsFloatingPointType(prim_type)) {
- LiteralTestUtil::ExpectNear(*expected, *actual, ErrorSpec(1e-4));
+ EXPECT_TRUE(LiteralTestUtil::Near(*expected, *actual, ErrorSpec(1e-4)));
} else {
- LiteralTestUtil::ExpectEqual(*expected, *actual);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual));
}
}
const4, reshape3, add2, const1, const0},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectNear(*Literal::CreateR2<float>({{0.5}, {2.72}}),
- *ExecuteAndTransfer(std::move(hlo_module), {}),
- ErrorSpec(1e-4));
+ EXPECT_TRUE(LiteralTestUtil::Near(
+ *Literal::CreateR2<float>({{0.5}, {2.72}}),
+ *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
}
// Test whether we emit appropriate code for parameters of fusion instructions.
->CreateFusionInstruction(/*instructions_to_fuse=*/{add3, const2},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectNear(*Literal::CreateR2<float>({{-1.0, 0.0, 1.0}}),
- *ExecuteAndTransfer(std::move(hlo_module), {}),
- ErrorSpec(1e-4));
+ EXPECT_TRUE(LiteralTestUtil::Near(
+ *Literal::CreateR2<float>({{-1.0, 0.0, 1.0}}),
+ *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
}
XLA_TEST_F(FusionTest, RandomizedParallelPartition) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{add2, broadcast},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectNear(
+ EXPECT_TRUE(LiteralTestUtil::Near(
*Literal::CreateR2<float>({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}),
- *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4));
+ *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)));
}
XLA_TEST_F(FusionTest, ReshapeToScalar) {
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(5),
- *ExecuteAndTransfer(std::move(hlo_module), {}));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(*Literal::CreateR0<int32>(5),
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) {
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectEqual(
+ EXPECT_TRUE(LiteralTestUtil::Equal(
*Literal::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}}),
- *ExecuteAndTransfer(std::move(hlo_module), {}));
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) {
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectEqual(
+ EXPECT_TRUE(LiteralTestUtil::Equal(
*Literal::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}}),
- *ExecuteAndTransfer(std::move(hlo_module), {}));
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Reshape_1by1by1_) {
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(7),
- *ExecuteAndTransfer(std::move(hlo_module), {}));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(*Literal::CreateR0<int32>(7),
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Reshape__1by1by1) {
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectEqual(*Literal::CreateR3<int32>({{{7}}}),
- *ExecuteAndTransfer(std::move(hlo_module), {}));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(*Literal::CreateR3<int32>({{{7}}}),
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Reshape__) {
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(7),
- *ExecuteAndTransfer(std::move(hlo_module), {}));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(*Literal::CreateR0<int32>(7),
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Reshape_3by3_3by3) {
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectEqual(
+ EXPECT_TRUE(LiteralTestUtil::Equal(
*Literal::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}),
- *ExecuteAndTransfer(std::move(hlo_module), {}));
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Transpose_2by3) {
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectEqual(
+ EXPECT_TRUE(LiteralTestUtil::Equal(
*Literal::CreateR2<int32>({{1, 4}, {2, 5}, {3, 6}}),
- *ExecuteAndTransfer(std::move(hlo_module), {}));
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Transpose_3by3) {
hlo_module->AddEntryComputation(builder.Build())
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectEqual(
+ EXPECT_TRUE(LiteralTestUtil::Equal(
*Literal::CreateR2<int32>({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}),
- *ExecuteAndTransfer(std::move(hlo_module), {}));
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Reverse) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reverse1},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({3, 2, 1}),
- *ExecuteAndTransfer(std::move(hlo_module), {}));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(*Literal::CreateR1<int32>({3, 2, 1}),
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, ReverseNegate) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reverse1},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({-3, -2, -1}),
- *ExecuteAndTransfer(std::move(hlo_module), {}));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(*Literal::CreateR1<int32>({-3, -2, -1}),
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, BroadcastNegate) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, broadcast1},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({-1, -1}),
- *ExecuteAndTransfer(std::move(hlo_module), {}));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(*Literal::CreateR1<int32>({-1, -1}),
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, SliceNegate) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, slice1},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({-1, -3}),
- *ExecuteAndTransfer(std::move(hlo_module), {}));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(*Literal::CreateR1<int32>({-1, -3}),
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, DynamicSliceNegate) {
/*instructions_to_fuse=*/{negate3, dynamic_slice2},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({-2, -3}),
- *ExecuteAndTransfer(std::move(hlo_module), {}));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(*Literal::CreateR1<int32>({-2, -3}),
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, ReshapeNegate) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, reshape1},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectEqual(*Literal::CreateR2<int32>({{-1, -2}, {-3, -4}}),
- *ExecuteAndTransfer(std::move(hlo_module), {}));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{-1, -2}, {-3, -4}}),
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
// TODO(b/64070202): Investigate failure.
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate2, transpose1},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectEqual(*Literal::CreateR2<int32>({{-1, -3}, {-2, -4}}),
- *ExecuteAndTransfer(std::move(hlo_module), {}));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(*Literal::CreateR2<int32>({{-1, -3}, {-2, -4}}),
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
std::unique_ptr<HloComputation> MakeReduceTestComputation() {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce2},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(15),
- *ExecuteAndTransfer(std::move(hlo_module), {}));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(*Literal::CreateR0<int32>(15),
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{negate3, reduce2},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectEqual(*Literal::CreateR0<int32>(-15),
- *ExecuteAndTransfer(std::move(hlo_module), {}));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(*Literal::CreateR0<int32>(-15),
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) {
->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce_window2},
HloInstruction::FusionKind::kLoop);
- LiteralTestUtil::ExpectEqual(
+ EXPECT_TRUE(LiteralTestUtil::Equal(
*Literal::CreateR2<int32>({{462, 2145}, {24871, 62491}}),
- *ExecuteAndTransfer(std::move(hlo_module), {}));
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
// When a constant (or other op) which has multiple users is imported
// fused instruction contains the constant(2), the parameter, and 4 adds
EXPECT_EQ(entry_comp->root_instruction()->fused_instruction_count(), 6);
- LiteralTestUtil::ExpectEqual(*Literal::CreateR1<int32>({8}),
- *ExecuteAndTransfer(std::move(hlo_module), {}));
+ EXPECT_TRUE(
+ LiteralTestUtil::Equal(*Literal::CreateR1<int32>({8}),
+ *ExecuteAndTransfer(std::move(hlo_module), {})));
}
XLA_TEST_F(FusionTest, Add2D) { TestElementwise2D<float, 2>(HloOpcode::kAdd); }
client_->ExecuteParallel(computation_instances));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Literal> result_literal,
client_->Transfer(*(result_data[0])));
- LiteralTestUtil::ExpectEqual(
- *result_literal, *Literal::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}}));
+ EXPECT_TRUE(LiteralTestUtil::Equal(
+ *result_literal, *Literal::CreateR2<int32>({{1, 2, 3}, {7, 8, 9}})));
}
} // namespace
} // namespace xla
#include "tensorflow/compiler/xla/index_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal_comparison.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
/* static */ ::testing::AssertionResult LiteralTestUtil::EqualShapes(
const Shape& expected, const Shape& actual) {
- if (ShapeUtil::IsTuple(expected) != ShapeUtil::IsTuple(actual)) {
- return ::testing::AssertionFailure()
- << "tupleness-mismatch! want: " << ShapeUtil::HumanString(expected)
- << " got: " << ShapeUtil::HumanString(actual);
- }
- if (ShapeUtil::IsTuple(expected)) {
- if (ShapeUtil::TupleElementCount(expected) !=
- ShapeUtil::TupleElementCount(actual)) {
- return ::testing::AssertionFailure()
- << "want tuple element count: "
- << ShapeUtil::TupleElementCount(expected)
- << " got tuple element count: "
- << ShapeUtil::TupleElementCount(actual);
- }
- for (int i = 0; i < expected.tuple_shapes_size(); ++i) {
- ::testing::AssertionResult result =
- EqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i))
- << "mismatch in tuple index " << i;
- if (!result) {
- return result;
- }
- }
- } else {
- if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) {
- return ::testing::AssertionFailure()
- << "want rank of: " << ShapeUtil::HumanString(expected)
- << " got rank of: " << ShapeUtil::HumanString(actual);
- }
- if (expected.element_type() != actual.element_type()) {
- return ::testing::AssertionFailure()
- << PrimitiveType_Name(expected.element_type()) << " vs "
- << PrimitiveType_Name(actual.element_type());
- }
- if (expected.dimensions_size() != actual.dimensions_size()) {
- return ::testing::AssertionFailure()
- << "want dimensions_size " << expected.dimensions_size()
- << " got dimensions_size " << actual.dimensions_size();
- }
- for (int i = 0; i < expected.dimensions_size(); ++i) {
- if (expected.dimensions(i) != actual.dimensions(i)) {
- return ::testing::AssertionFailure()
- << "mismatch in dimension #" << i
- << " expected: " << ShapeUtil::HumanString(expected)
- << " actual: " << ShapeUtil::HumanString(actual);
- }
- }
+ Status result = literal_comparison::EqualShapes(expected, actual);
+ if (result.ok()) {
+ return ::testing::AssertionSuccess();
}
- return ::testing::AssertionSuccess();
-}
-
-/* static */ void LiteralTestUtil::AssertEqualShapes(const Shape& expected,
- const Shape& actual) {
- ASSERT_TRUE(EqualShapes(expected, actual));
+ return ::testing::AssertionFailure() << result;
}
-/* static */ void LiteralTestUtil::AssertEqualShapesAndLayouts(
+/* static */ ::testing::AssertionResult LiteralTestUtil::EqualShapesAndLayouts(
const Shape& expected, const Shape& actual) {
- ASSERT_EQ(expected.ShortDebugString(), actual.ShortDebugString());
-}
-
-namespace {
-
-// Return a literal with all arrays of type FromNativeT converted to type
-// ToNativeT in the given literal.
-template <typename FromNativeT, typename ToNativeT>
-std::unique_ptr<Literal> ConvertType(LiteralSlice literal) {
- // First construct shape of the result.
- Shape result_shape(literal.shape());
- ShapeUtil::ForEachMutableSubshape(
- &result_shape, [](Shape* subshape, const ShapeIndex&) {
- if (subshape->element_type() ==
- primitive_util::NativeToPrimitiveType<FromNativeT>()) {
- subshape->set_element_type(
- primitive_util::NativeToPrimitiveType<ToNativeT>());
- }
- });
- auto result = MakeUnique<Literal>(result_shape);
-
- // Then copy over the data from 'literal' converting FromNativeT values to
- // ToNativeT values as necessary.
- ShapeUtil::ForEachSubshape(
- literal.shape(),
- [&](const Shape& subshape, const ShapeIndex& shape_index) {
- if (ShapeUtil::IsArray(subshape)) {
- if (subshape.element_type() ==
- primitive_util::NativeToPrimitiveType<FromNativeT>()) {
- auto src = literal.data<FromNativeT>(shape_index);
- auto dest = result->data<ToNativeT>(shape_index);
- for (int64 i = 0; i < src.size(); ++i) {
- dest[i] = static_cast<ToNativeT>(src[i]);
- }
- } else {
- TF_CHECK_OK(result->CopyFrom(literal,
- /*dest_shape_index=*/shape_index,
- /*src_shape_index=*/shape_index));
- }
- }
- });
- return result;
-}
-
-} // namespace
-
-/* static */ std::unique_ptr<Literal> LiteralTestUtil::ConvertBF16ToF32(
- LiteralSlice literal) {
- return ConvertType<bfloat16, float>(literal);
-}
-
-/* static */ std::unique_ptr<Literal> LiteralTestUtil::ConvertF32ToBF16(
- LiteralSlice literal) {
- return ConvertType<float, bfloat16>(literal);
+ if (expected.ShortDebugString() != actual.ShortDebugString()) {
+ return ::testing::AssertionFailure()
+ << "want: " << expected.ShortDebugString()
+ << " got: " << actual.ShortDebugString();
+ }
+ return ::testing::AssertionSuccess();
}
namespace {
return string(hostname);
}
-// Helper function for comparing a floating point type, FloatT, bitwise equal
-// between the left-hand-side and right-hand-side, by bit-casting to UnsignedT
-// -- on miscompare, a nice error message is given in the AssertionFailure.
-template <typename FloatT, typename UnsignedT>
-::testing::AssertionResult CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) {
- auto ulhs = tensorflow::bit_cast<UnsignedT>(lhs);
- auto urhs = tensorflow::bit_cast<UnsignedT>(rhs);
- auto lhs_double = static_cast<double>(lhs);
- auto rhs_double = static_cast<double>(rhs);
- if (ulhs != urhs) {
- return ::testing::AssertionFailure() << Printf(
- "floating values are not bitwise-equal; and equality testing "
- "was requested: %s=%g=%a vs %s=%g=%a",
- StrCat(tensorflow::strings::Hex(ulhs)).c_str(), lhs_double,
- lhs_double, StrCat(tensorflow::strings::Hex(urhs)).c_str(),
- rhs_double, rhs_double);
- }
- return ::testing::AssertionSuccess();
-}
-
-// Templated comparator that specializes for float equality comparison with the
-// bitwise helper above (this is the un-specialized fallback, to just use the
-// default gunit implementation).
-template <typename NativeT>
-::testing::AssertionResult CompareEqual(NativeT lhs, NativeT rhs) {
- if (lhs == rhs) {
- return ::testing::AssertionSuccess();
- }
- ::testing::Message msg;
- msg << "Expected equality of these values:";
- msg << "\n " << lhs;
- msg << "\n " << rhs;
-
- return ::testing::AssertionFailure() << msg;
-}
-
-// Specializations for floating types that do bitwise comparisons when equality
-// comparison is requested.
-template <>
-::testing::AssertionResult CompareEqual<bfloat16>(bfloat16 lhs, bfloat16 rhs) {
- return CompareFloatsBitwiseEqual<bfloat16, uint16>(lhs, rhs);
-}
-template <>
-::testing::AssertionResult CompareEqual<Eigen::half>(Eigen::half lhs,
- Eigen::half rhs) {
- return CompareFloatsBitwiseEqual<Eigen::half, uint16>(lhs, rhs);
-}
-template <>
-::testing::AssertionResult CompareEqual<float>(float lhs, float rhs) {
- return CompareFloatsBitwiseEqual<float, uint32>(lhs, rhs);
-}
-template <>
-::testing::AssertionResult CompareEqual<double>(double lhs, double rhs) {
- return CompareFloatsBitwiseEqual<double, uint64>(lhs, rhs);
-}
-template <>
-::testing::AssertionResult CompareEqual<complex64>(complex64 lhs,
- complex64 rhs) {
- auto res = CompareEqual<float>(lhs.real(), rhs.real());
- if (!res) {
- return res;
- }
- return CompareEqual<float>(lhs.imag(), rhs.imag());
-}
-
-// A recursive function which iterates through every index of expected and
-// actual literal and compares their values elementwise. Returns true if all
-// elements are equal.
-template <typename NativeT>
-bool ExpectLiteralsEqual(LiteralSlice expected, LiteralSlice actual,
- tensorflow::gtl::MutableArraySlice<int64> multi_index,
- int64 dimension) {
- if (dimension == expected.shape().dimensions_size()) {
- NativeT expected_value = expected.Get<NativeT>(multi_index);
- NativeT actual_value = actual.Get<NativeT>(multi_index);
- ::testing::AssertionResult result =
- CompareEqual<NativeT>(expected_value, actual_value);
- return result; // Defines implicit coersion to bool.
- }
-
- bool all_match = true;
- for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) {
- multi_index[dimension] = i;
- all_match = all_match && ExpectLiteralsEqual<NativeT>(
- expected, actual, multi_index, dimension + 1);
- }
- return all_match;
-}
-
} // namespace
-/* static */ void LiteralTestUtil::ExpectEqual(LiteralSlice expected,
- LiteralSlice actual,
- const string& message) {
- EXPECT_TRUE(Equal(expected, actual))
- << "expected:\n"
- << expected.ToString() << "\n\tvs actual:\n"
- << actual.ToString()
- << (message.empty() ? "" : StrCat("\nmessage: ", message));
-}
-
-/* static */ void LiteralTestUtil::ExpectNotEqual(LiteralSlice expected,
- LiteralSlice actual) {
- EXPECT_FALSE(Equal(expected, actual));
-}
-
/* static */ ::testing::AssertionResult LiteralTestUtil::Equal(
- LiteralSlice expected, LiteralSlice actual) {
- VLOG(1) << "expected:";
- XLA_VLOG_LINES(1, expected.ToString());
- VLOG(1) << "actual:";
- XLA_VLOG_LINES(1, actual.ToString());
-
- AssertEqualShapes(expected.shape(), actual.shape());
- std::vector<int64> multi_index(expected.shape().dimensions_size(), 0);
- bool match = false;
- switch (expected.shape().element_type()) {
- case PRED:
- match = ExpectLiteralsEqual<bool>(expected, actual, &multi_index, 0);
- break;
- case U8:
- match = ExpectLiteralsEqual<uint8>(expected, actual, &multi_index, 0);
- break;
- case S32:
- match = ExpectLiteralsEqual<int32>(expected, actual, &multi_index, 0);
- break;
- case S64:
- match = ExpectLiteralsEqual<int64>(expected, actual, &multi_index, 0);
- break;
- case U32:
- match = ExpectLiteralsEqual<uint32>(expected, actual, &multi_index, 0);
- break;
- case U64:
- match = ExpectLiteralsEqual<uint64>(expected, actual, &multi_index, 0);
- break;
- case BF16:
- match = ExpectLiteralsEqual<bfloat16>(expected, actual, &multi_index, 0);
- break;
- case F16:
- match = ExpectLiteralsEqual<half>(expected, actual, &multi_index, 0);
- break;
- case F32:
- match = ExpectLiteralsEqual<float>(expected, actual, &multi_index, 0);
- break;
- case F64:
- match = ExpectLiteralsEqual<double>(expected, actual, &multi_index, 0);
- break;
- case C64:
- match = ExpectLiteralsEqual<complex64>(expected, actual, &multi_index, 0);
- break;
- case TUPLE: {
- bool tuple_match = true;
- for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) {
- SCOPED_TRACE(StrCat("Tuple index ", i, " in ",
- ShapeUtil::HumanString(expected.shape())));
-
- // Create LiteralSlices of the expected and actual elements.
- auto result =
- Equal(LiteralSlice(expected, {i}), LiteralSlice(actual, {i}));
- tuple_match = tuple_match ? !!result : false;
- }
- match = tuple_match;
- break;
- }
- default:
- LOG(FATAL)
- << "Unsupported primitive type in LiteralTestUtil::ExpectEqual: "
- << PrimitiveType_Name(expected.shape().element_type());
- }
- ::testing::AssertionResult result = ::testing::AssertionSuccess();
- if (!match) {
- result = ::testing::AssertionFailure()
- << "expected: " << expected.ToString()
- << "\nactual: " << actual.ToString();
- VLOG(1) << result.message();
+ const LiteralSlice& expected, const LiteralSlice& actual) {
+ Status result = literal_comparison::Equal(expected, actual);
+ if (result.ok()) {
+ return ::testing::AssertionSuccess();
}
- return result;
+ return ::testing::AssertionFailure() << result;
}
namespace {
// 3 minutes. The utility of printing a literal with >1000 elements is
// questionable, especially when writing the Literal proto to disk is orders
// of magnitude faster.
-string TruncateHugeLiteral(LiteralSlice literal) {
+string TruncateHugeLiteral(const LiteralSlice& literal) {
return RecursiveElementCount(literal.shape()) < 1000
? literal.ToString()
: "[TRUNCATED, Literal with more than 1000 values]";
// result. The assertion result is successful if all actual and expected
// elements are within the given error bound. In case of error, the assertion
// result contains a detailed error message in case of failure.
- static ::testing::AssertionResult Compare(LiteralSlice expected,
- LiteralSlice actual,
+ static ::testing::AssertionResult Compare(const LiteralSlice& expected,
+ const LiteralSlice& actual,
ErrorSpec error,
bool detailed_message) {
NearComparator<NativeT> comparator(expected, actual, error,
return Printf(
"actual %s, expected %s, index %s, rel error %8.3g, abs error %8.3g",
FpValueToString(actual).c_str(), FpValueToString(expected).c_str(),
- LiteralTestUtil::MultiIndexAsString(
+ Literal::MultiIndexAsString(
IndexUtil::LinearIndexToMultidimensionalIndex(shape,
linear_index))
.c_str(),
}
};
- explicit NearComparator(LiteralSlice expected, LiteralSlice actual,
- ErrorSpec error, bool detailed_message)
+ explicit NearComparator(const LiteralSlice& expected,
+ const LiteralSlice& actual, ErrorSpec error,
+ bool detailed_message)
: expected_(expected),
actual_(actual),
error_(error),
}
// Writes the given literal to a file in the test temporary directory.
- void WriteLiteralToTempFile(LiteralSlice literal, const string& name) {
+ void WriteLiteralToTempFile(const LiteralSlice& literal, const string& name) {
int64 now_usec = tensorflow::Env::Default()->NowMicros();
string filename = tensorflow::io::JoinPath(
tensorflow::testing::TmpDir(),
// Helper function for comparing two literals for nearness. Handles tuple-shapes
// via recursion. shape_index is the ShapeIndex of expected (or actual)
// currently being compared.
-::testing::AssertionResult NearHelper(LiteralSlice expected,
- LiteralSlice actual,
+::testing::AssertionResult NearHelper(const LiteralSlice& expected,
+ const LiteralSlice& actual,
const ErrorSpec& error,
bool detailed_message,
const ShapeIndex& shape_index) {
} // namespace
/* static */ ::testing::AssertionResult LiteralTestUtil::Near(
- LiteralSlice expected, LiteralSlice actual, const ErrorSpec& error,
- bool detailed_message) {
+ const LiteralSlice& expected, const LiteralSlice& actual,
+ const ErrorSpec& error, bool detailed_message) {
return NearHelper(expected, actual, error, detailed_message,
/*shape_index=*/{});
}
-/* static */ void LiteralTestUtil::ExpectNear(LiteralSlice expected,
- LiteralSlice actual,
- const ErrorSpec& error,
- const string& message) {
- ::testing::AssertionResult res =
- Near(expected, actual, error, /*detailed_message=*/false);
- if (!res) {
- res << "Expected: " << TruncateHugeLiteral(expected) << "\n";
- res << "Actual: " << TruncateHugeLiteral(actual) << "\n";
- if (!message.empty()) {
- res << StrCat("\nmessage: ", message);
- }
- }
- EXPECT_TRUE(res);
-}
-
-/*static*/ ::testing::AssertionResult LiteralTestUtil::NearOrEqual(
- LiteralSlice expected, LiteralSlice actual,
+/* static */ ::testing::AssertionResult LiteralTestUtil::NearOrEqual(
+ const LiteralSlice& expected, const LiteralSlice& actual,
const tensorflow::gtl::optional<ErrorSpec>& error) {
if (error.has_value()) {
VLOG(1) << "Expects near";
return Equal(expected, actual);
}
-/*static*/ void LiteralTestUtil::ExpectNearOrEqual(
- LiteralSlice expected, LiteralSlice actual,
- const tensorflow::gtl::optional<ErrorSpec>& error) {
- EXPECT_TRUE(NearOrEqual(expected, actual, error));
-}
-
-/* static */ string LiteralTestUtil::MultiIndexAsString(
- tensorflow::gtl::ArraySlice<int64> multi_index) {
- return StrCat("{", tensorflow::str_util::Join(multi_index, ","), "}");
-}
-
-/* static */ std::unique_ptr<Literal> LiteralTestUtil::Reshape(
- tensorflow::gtl::ArraySlice<int64> new_dimensions,
- tensorflow::gtl::ArraySlice<int64> minor_to_major, LiteralSlice literal) {
- int64 new_num_elements = 1;
- for (int64 i = 0; i < new_dimensions.size(); ++i) {
- new_num_elements *= new_dimensions[i];
- }
- CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements);
- CHECK_EQ(new_dimensions.size(), minor_to_major.size());
-
- auto new_literal = MakeUnique<Literal>(
- ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions));
-
- // Create a new shape with the given minor-to-major layout. This shape is used
- // solely for converting linear address to multi-dimensional addresses when
- // writing elements to the new literal.
- Shape shape_with_layout = new_literal->shape();
- *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
-
- // Copy data into new literal, element-by-element.
- for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) {
- std::vector<int64> from_multi_index =
- IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i);
- std::vector<int64> to_multi_index =
- IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i);
- switch (literal.shape().element_type()) {
- case PRED:
- new_literal->Set<bool>(to_multi_index,
- literal.Get<bool>(from_multi_index));
- break;
- case U8:
- new_literal->Set<uint8>(to_multi_index,
- literal.Get<uint8>(from_multi_index));
- break;
- case U32:
- new_literal->Set<uint32>(to_multi_index,
- literal.Get<uint32>(from_multi_index));
- break;
- case S32:
- new_literal->Set<int32>(to_multi_index,
- literal.Get<int32>(from_multi_index));
- break;
- case U64:
- new_literal->Set<uint64>(to_multi_index,
- literal.Get<uint64>(from_multi_index));
- break;
- case S64:
- new_literal->Set<int64>(to_multi_index,
- literal.Get<int64>(from_multi_index));
- break;
- case F32:
- new_literal->Set<float>(to_multi_index,
- literal.Get<float>(from_multi_index));
- break;
- case F64:
- new_literal->Set<double>(to_multi_index,
- literal.Get<double>(from_multi_index));
- break;
- case C64:
- new_literal->Set<complex64>(to_multi_index,
- literal.Get<complex64>(from_multi_index));
- break;
- default:
- LOG(FATAL) << "Unhandled primitive element type: "
- << PrimitiveType_Name(literal.shape().element_type());
- }
- }
-
- return new_literal;
-}
-
} // namespace xla
public:
// Asserts that the given shapes have the same rank, dimension sizes, and
// primitive types.
- static ::testing::AssertionResult EqualShapes(const Shape& expected,
- const Shape& actual);
- static void AssertEqualShapes(const Shape& expected, const Shape& actual);
+ static ::testing::AssertionResult EqualShapes(
+ const Shape& expected, const Shape& actual) MUST_USE_RESULT;
// Asserts that the provided shapes are equal as defined in AssertEqualShapes
// and that they have the same layout.
- static void AssertEqualShapesAndLayouts(const Shape& expected,
- const Shape& actual);
+ static ::testing::AssertionResult EqualShapesAndLayouts(
+ const Shape& expected, const Shape& actual) MUST_USE_RESULT;
- // If the given literal's data type is bfloat16, converts it to a float
- // literal; otherwise, returns a copy of it. If the literal is a tuple,
- // recursively converts its elements.
- static std::unique_ptr<Literal> ConvertBF16ToF32(LiteralSlice bf16_literal);
-
- // If the given literal's data type is float, converts it to a bfloat16
- // literal; otherwise, returns a copy of it. If the literal is a tuple,
- // recursively converts its elements.
- static std::unique_ptr<Literal> ConvertF32ToBF16(LiteralSlice f32_literal);
-
- // Asserts that the expected and actual literals are (bitwise) equal for all
- // elements in the literal. Also, asserts that the rank, dimensions sizes, and
- // primitive type are equal.
- static ::testing::AssertionResult Equal(
- LiteralSlice expected, LiteralSlice actual) TF_MUST_USE_RESULT;
-
- // Expects that expected and actual are Equal.
- static void ExpectEqual(LiteralSlice expected, LiteralSlice actual,
- const string& message = "");
-
- // Expects that expected and actual are Not Equal.
- static void ExpectNotEqual(LiteralSlice expected, LiteralSlice actual);
+ static ::testing::AssertionResult Equal(const LiteralSlice& expected,
+ const LiteralSlice& actual)
+ TF_MUST_USE_RESULT;
// Asserts the given literal are (bitwise) equal to given expected values.
template <typename NativeT>
- static void ExpectR0Equal(NativeT expected, LiteralSlice actual);
+ static void ExpectR0Equal(NativeT expected, const LiteralSlice& actual);
+
template <typename NativeT>
static void ExpectR1Equal(tensorflow::gtl::ArraySlice<NativeT> expected,
- LiteralSlice actual);
+ const LiteralSlice& actual);
template <typename NativeT>
static void ExpectR2Equal(
std::initializer_list<std::initializer_list<NativeT>> expected,
- LiteralSlice actual);
+ const LiteralSlice& actual);
+
template <typename NativeT>
static void ExpectR3Equal(
std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>
expected,
- LiteralSlice actual);
+ const LiteralSlice& actual);
// Asserts the given literal are (bitwise) equal to given array.
template <typename NativeT>
static void ExpectR2EqualArray2D(const Array2D<NativeT>& expected,
- LiteralSlice actual);
+ const LiteralSlice& actual);
template <typename NativeT>
static void ExpectR3EqualArray3D(const Array3D<NativeT>& expected,
- LiteralSlice actual);
+ const LiteralSlice& actual);
template <typename NativeT>
static void ExpectR4EqualArray4D(const Array4D<NativeT>& expected,
- LiteralSlice actual);
+ const LiteralSlice& actual);
// Asserts that the expected and actual literals are within the given error
// bound for all elements. Also, asserts that the rank, dimensions sizes, and
// If detailed_message is true, then the error message in the assertion result
// will contain a more detailed breakdown of mismatches.
static ::testing::AssertionResult Near(
- LiteralSlice expected, LiteralSlice actual, const ErrorSpec& error,
- bool detailed_message = false) TF_MUST_USE_RESULT;
-
- // Expects expected and actual to be Near with the given error.
- static void ExpectNear(LiteralSlice expected, LiteralSlice actual,
- const ErrorSpec& error, const string& message = "");
+ const LiteralSlice& expected, const LiteralSlice& actual,
+ const ErrorSpec& error, bool detailed_message = false) TF_MUST_USE_RESULT;
// Asserts the given literal are within the given error bound of the given
// expected values. Only supported for floating point values.
template <typename NativeT>
- static void ExpectR0Near(NativeT expected, LiteralSlice actual,
+ static void ExpectR0Near(NativeT expected, const LiteralSlice& actual,
const ErrorSpec& error);
+
template <typename NativeT>
static void ExpectR1Near(tensorflow::gtl::ArraySlice<NativeT> expected,
- LiteralSlice actual, const ErrorSpec& error);
+ const LiteralSlice& actual, const ErrorSpec& error);
+
template <typename NativeT>
static void ExpectR2Near(
std::initializer_list<std::initializer_list<NativeT>> expected,
- LiteralSlice actual, const ErrorSpec& error);
+ const LiteralSlice& actual, const ErrorSpec& error);
+
template <typename NativeT>
static void ExpectR3Near(
std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>
expected,
- LiteralSlice actual, const ErrorSpec& error);
+ const LiteralSlice& actual, const ErrorSpec& error);
+
template <typename NativeT>
static void ExpectR4Near(
std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>>
expected,
- LiteralSlice actual, const ErrorSpec& error);
+ const LiteralSlice& actual, const ErrorSpec& error);
// Asserts the given literal are within the given error bound to the given
// array. Only supported for floating point values.
template <typename NativeT>
static void ExpectR2NearArray2D(const Array2D<NativeT>& expected,
- LiteralSlice actual, const ErrorSpec& error);
+ const LiteralSlice& actual,
+ const ErrorSpec& error);
+
template <typename NativeT>
static void ExpectR3NearArray3D(const Array3D<NativeT>& expected,
- LiteralSlice actual, const ErrorSpec& error);
+ const LiteralSlice& actual,
+ const ErrorSpec& error);
+
template <typename NativeT>
static void ExpectR4NearArray4D(const Array4D<NativeT>& expected,
- LiteralSlice actual, const ErrorSpec& error);
+ const LiteralSlice& actual,
+ const ErrorSpec& error);
// If the error spec is given, returns whether the expected and the actual are
// within the error bound; otherwise, returns whether they are equal. Tuples
// will be compared recursively.
static ::testing::AssertionResult NearOrEqual(
- LiteralSlice expected, LiteralSlice actual,
+ const LiteralSlice& expected, const LiteralSlice& actual,
const tensorflow::gtl::optional<ErrorSpec>& error) TF_MUST_USE_RESULT;
- // If the error spec is given, expects the expected and the actual to be near;
- // otherwise, expects them to be equal. Tuples will be compared recursively.
- static void ExpectNearOrEqual(
- LiteralSlice expected, LiteralSlice actual,
- const tensorflow::gtl::optional<ErrorSpec>& error);
-
- // Returns a multi-dimensional index as a string. For example: '{7, 8}' will
- // be returned for a 2-dimensional index with dimension 0 index equal to 7,
- // dimension 1 equal to 8.
- static string MultiIndexAsString(
- tensorflow::gtl::ArraySlice<int64> multi_index);
-
- // Creates a literal with a new shape with the given new dimensions using the
- // data in the given input literal. For reshaping purposes the (flat) data
- // buffer of the input literal is assumed to have the given minor_to_major
- // layout order.
- static std::unique_ptr<Literal> Reshape(
- tensorflow::gtl::ArraySlice<int64> new_dimensions,
- tensorflow::gtl::ArraySlice<int64> minor_to_major, LiteralSlice literal);
-
- // Creates a literal with the supplied shape, and uses the provided value
- // generator to populate the literal's values.
- // Returns the new literal object, or an error Status if failed.
- template <
- PrimitiveType type,
- typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
- static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
- const Shape& shape,
- const std::function<T(tensorflow::gtl::ArraySlice<int64>)>& generator);
-
- // Creates a literal with the supplied shape, and initializes the literal
- // values using a normal distribution with given mean and stddev standard
- // deviation, and using the engine as entropy generator.
- // Returns the new literal object, or an error Status if failed.
- template <
- PrimitiveType type, typename E,
- typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
- static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
- const Shape& shape, E* engine, T mean, T stddev);
-
- // Creates a literal with the supplied shape, and initializes the literal
- // values using a normal distribution with given mean and stddev standard
- // deviation.
- // Returns the new literal object, or an error Status if failed.
- template <
- PrimitiveType type,
- typename T = typename primitive_util::PrimitiveTypeToNative<type>::type>
- static StatusOr<std::unique_ptr<Literal>> CreateRandomLiteral(
- const Shape& shape, T mean, T stddev);
-
private:
TF_DISALLOW_COPY_AND_ASSIGN(LiteralTestUtil);
};
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR0Equal(NativeT expected,
- LiteralSlice actual) {
- ExpectEqual(*Literal::CreateR0<NativeT>(expected), actual);
+ const LiteralSlice& actual) {
+ EXPECT_TRUE(Equal(*Literal::CreateR0<NativeT>(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR1Equal(
- tensorflow::gtl::ArraySlice<NativeT> expected, LiteralSlice actual) {
- ExpectEqual(*Literal::CreateR1<NativeT>(expected), actual);
+ tensorflow::gtl::ArraySlice<NativeT> expected, const LiteralSlice& actual) {
+ EXPECT_TRUE(Equal(*Literal::CreateR1<NativeT>(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR2Equal(
std::initializer_list<std::initializer_list<NativeT>> expected,
- LiteralSlice actual) {
- ExpectEqual(*Literal::CreateR2<NativeT>(expected), actual);
+ const LiteralSlice& actual) {
+ EXPECT_TRUE(Equal(*Literal::CreateR2<NativeT>(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR3Equal(
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
expected,
- LiteralSlice actual) {
- ExpectEqual(*Literal::CreateR3<NativeT>(expected), actual);
+ const LiteralSlice& actual) {
+ EXPECT_TRUE(Equal(*Literal::CreateR3<NativeT>(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR2EqualArray2D(
- const Array2D<NativeT>& expected, LiteralSlice actual) {
- ExpectEqual(*Literal::CreateR2FromArray2D(expected), actual);
+ const Array2D<NativeT>& expected, const LiteralSlice& actual) {
+ EXPECT_TRUE(Equal(*Literal::CreateR2FromArray2D(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR3EqualArray3D(
- const Array3D<NativeT>& expected, LiteralSlice actual) {
- ExpectEqual(*Literal::CreateR3FromArray3D(expected), actual);
+ const Array3D<NativeT>& expected, const LiteralSlice& actual) {
+ EXPECT_TRUE(Equal(*Literal::CreateR3FromArray3D(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR4EqualArray4D(
- const Array4D<NativeT>& expected, LiteralSlice actual) {
- ExpectEqual(*Literal::CreateR4FromArray4D(expected), actual);
+ const Array4D<NativeT>& expected, const LiteralSlice& actual) {
+ EXPECT_TRUE(Equal(*Literal::CreateR4FromArray4D(expected), actual));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR0Near(NativeT expected,
- LiteralSlice actual,
+ const LiteralSlice& actual,
const ErrorSpec& error) {
- ExpectNear(*Literal::CreateR0<NativeT>(expected), actual, error);
+ EXPECT_TRUE(Near(*Literal::CreateR0<NativeT>(expected), actual, error));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR1Near(
- tensorflow::gtl::ArraySlice<NativeT> expected, LiteralSlice actual,
+ tensorflow::gtl::ArraySlice<NativeT> expected, const LiteralSlice& actual,
const ErrorSpec& error) {
- ExpectNear(*Literal::CreateR1<NativeT>(expected), actual, error);
+ EXPECT_TRUE(Near(*Literal::CreateR1<NativeT>(expected), actual, error));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR2Near(
std::initializer_list<std::initializer_list<NativeT>> expected,
- LiteralSlice actual, const ErrorSpec& error) {
- ExpectNear(*Literal::CreateR2<NativeT>(expected), actual, error);
+ const LiteralSlice& actual, const ErrorSpec& error) {
+ EXPECT_TRUE(Near(*Literal::CreateR2<NativeT>(expected), actual, error));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR3Near(
std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
expected,
- LiteralSlice actual, const ErrorSpec& error) {
- ExpectNear(*Literal::CreateR3<NativeT>(expected), actual, error);
+ const LiteralSlice& actual, const ErrorSpec& error) {
+ EXPECT_TRUE(Near(*Literal::CreateR3<NativeT>(expected), actual, error));
}
template <typename NativeT>
std::initializer_list<std::initializer_list<
std::initializer_list<std::initializer_list<NativeT>>>>
expected,
- LiteralSlice actual, const ErrorSpec& error) {
- ExpectNear(*Literal::CreateR4<NativeT>(expected), actual, error);
+ const LiteralSlice& actual, const ErrorSpec& error) {
+ EXPECT_TRUE(Near(*Literal::CreateR4<NativeT>(expected), actual, error));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR2NearArray2D(
- const Array2D<NativeT>& expected, LiteralSlice actual,
+ const Array2D<NativeT>& expected, const LiteralSlice& actual,
const ErrorSpec& error) {
- ExpectNear(*Literal::CreateR2FromArray2D(expected), actual, error);
+ EXPECT_TRUE(Near(*Literal::CreateR2FromArray2D(expected), actual, error));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR3NearArray3D(
- const Array3D<NativeT>& expected, LiteralSlice actual,
+ const Array3D<NativeT>& expected, const LiteralSlice& actual,
const ErrorSpec& error) {
- ExpectNear(*Literal::CreateR3FromArray3D(expected), actual, error);
+ EXPECT_TRUE(Near(*Literal::CreateR3FromArray3D(expected), actual, error));
}
template <typename NativeT>
/* static */ void LiteralTestUtil::ExpectR4NearArray4D(
- const Array4D<NativeT>& expected, LiteralSlice actual,
+ const Array4D<NativeT>& expected, const LiteralSlice& actual,
const ErrorSpec& error) {
- ExpectNear(*Literal::CreateR4FromArray4D(expected), actual, error);
-}
-
-template <PrimitiveType type, typename T>
-/* static */ StatusOr<std::unique_ptr<Literal>>
-LiteralTestUtil::CreateRandomLiteral(
- const Shape& shape,
- const std::function<T(tensorflow::gtl::ArraySlice<int64>)>& generator) {
- using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
- TF_RET_CHECK(shape.element_type() == type);
- std::unique_ptr<Literal> literal = Literal::CreateFromShape(shape);
- TF_RETURN_IF_ERROR(literal.get()->Populate<NativeT>(
- [&](tensorflow::gtl::ArraySlice<int64> indexes) {
- return generator(indexes);
- }));
- return std::move(literal);
-}
-
-template <PrimitiveType type, typename E, typename T>
-/* static */ StatusOr<std::unique_ptr<Literal>>
-LiteralTestUtil::CreateRandomLiteral(const Shape& shape, E* engine, T mean,
- T stddev) {
- using NativeT = typename primitive_util::PrimitiveTypeToNative<type>::type;
- std::normal_distribution<NativeT> generator(mean, stddev);
- return CreateRandomLiteral<type, NativeT>(
- shape, [&](tensorflow::gtl::ArraySlice<int64> /*indexes*/) {
- return generator(*engine);
- });
-}
-
-template <PrimitiveType type, typename T>
-/* static */ StatusOr<std::unique_ptr<Literal>>
-LiteralTestUtil::CreateRandomLiteral(const Shape& shape, T mean, T stddev) {
- std::minstd_rand0 engine;
- return CreateRandomLiteral<type>(shape, &engine, mean, stddev);
+ EXPECT_TRUE(Near(*Literal::CreateR4FromArray4D(expected), actual, error));
}
} // namespace xla
std::unique_ptr<Literal> literal = Literal::MakeTuple({
Literal::CreateR0<int32>(42).get(), Literal::CreateR0<int32>(64).get(),
});
- LiteralTestUtil::ExpectEqual(*literal, *literal);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *literal));
}
TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) {
}
}
+TEST(LiteralTestUtilTest, NotEqualHasValuesInMessage) {
+ auto expected = Literal::CreateR1<int32>({1, 2, 3});
+ auto actual = Literal::CreateR1<int32>({4, 5, 6});
+ ::testing::AssertionResult result =
+ LiteralTestUtil::Equal(*expected, *actual);
+ EXPECT_THAT(result.message(), ::testing::HasSubstr("expected: {1, 2, 3}"));
+ EXPECT_THAT(result.message(), ::testing::HasSubstr("actual: {4, 5, 6}"));
+}
+
TEST(LiteralTestUtilTest, NearComparatorR1) {
auto a =
Literal::CreateR1<float>({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8});
expect.PopulateWithValue<float>(size * 1.5f * 3.5f);
auto actual = ExecuteAndTransfer(
std::move(hlo_module), {Literal::CreateR0<float>(-9.0f).get(), &arg1});
- LiteralTestUtil::ExpectNear(expect, *actual, error_spec_);
+ EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_));
}
void RunTest1D(bool manual_fusion, int size) {
Literal expect = std::move(*Literal::CreateR1<float>({size * 1.5f * 3.5f}));
auto actual = ExecuteAndTransfer(std::move(hlo_module), {&input0, &input1});
- LiteralTestUtil::ExpectNear(expect, *actual, error_spec_);
+ EXPECT_TRUE(LiteralTestUtil::Near(expect, *actual, error_spec_));
}
};
&execution_options_));
}
- LiteralTestUtil::ExpectEqual(*result1, *result2);
- LiteralTestUtil::ExpectEqual(*result1, *result3);
- LiteralTestUtil::ExpectNotEqual(*result1, *result4);
- LiteralTestUtil::ExpectNotEqual(*result4, *result5);
- LiteralTestUtil::ExpectNotEqual(*result5, *result6);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*result1, *result2));
+ EXPECT_TRUE(LiteralTestUtil::Equal(*result1, *result3));
+ EXPECT_FALSE(LiteralTestUtil::Equal(*result1, *result4));
+ EXPECT_FALSE(LiteralTestUtil::Equal(*result4, *result5));
+ EXPECT_FALSE(LiteralTestUtil::Equal(*result5, *result6));
}
XLA_TEST_F(PrngTest, TenValuesN01) {
std::unique_ptr<Literal> expected =
Literal::CreateR2FromArray2D<float>(expected_array);
if (use_bfloat16()) {
- expected = LiteralTestUtil::ConvertF32ToBF16(*expected);
+ expected = Literal::ConvertF32ToBF16(*expected);
}
- LiteralTestUtil::ExpectEqual(*expected, *actual);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected, *actual));
}
XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) {
builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1});
std::unique_ptr<Literal> expected =
- LiteralTestUtil::Reshape({2, 1}, {1, 0}, *input_literal);
+ Literal::ReshapeSlice({2, 1}, {1, 0}, *input_literal);
ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
zero_error_spec_);
}
builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2});
std::unique_ptr<Literal> expected =
- LiteralTestUtil::Reshape({4, 2}, {1, 0}, *input_literal);
+ Literal::ReshapeSlice({4, 2}, {1, 0}, *input_literal);
ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
zero_error_spec_);
}
// Since the reshape is a no-op, verify that it does not change the underlying
// data.
if (use_bfloat16()) {
- auto expected = LiteralTestUtil::ConvertF32ToBF16(*input_literal);
+ auto expected = Literal::ConvertF32ToBF16(*input_literal);
EXPECT_EQ(expected->data<bfloat16>(), output_literal->data<bfloat16>());
} else {
EXPECT_EQ(input_literal->data<float>(), output_literal->data<float>());
/*new_sizes=*/new_bounds);
std::unique_ptr<Literal> expected =
- LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal)
+ Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
// Specify the requested output shape explicitly to ensure that this reshape
/*new_sizes=*/new_bounds);
std::unique_ptr<Literal> expected =
- LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal)
+ Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
// Specify the requested output shape explicitly to ensure that this reshape
/*new_sizes=*/new_bounds);
std::unique_ptr<Literal> expected =
- LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal)
+ Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
// Specify the requested output shape explicitly to ensure that this reshape
/*new_sizes=*/new_bounds);
std::unique_ptr<Literal> expected =
- LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal)
+ Literal::ReshapeSlice(new_bounds, {2, 3, 1, 0}, *input_literal)
->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0}));
// Specify the requested output shape explicitly to ensure that this reshape
/*new_sizes=*/new_bounds);
std::unique_ptr<Literal> expected =
- LiteralTestUtil::Reshape(new_bounds, {1, 0, 2, 3}, *input_literal)
+ Literal::ReshapeSlice(new_bounds, {1, 0, 2, 3}, *input_literal)
->Relayout(input_literal->shape().layout());
// Specify the requested output shape explicitly to ensure that this reshape
EXPECT_EQ(46.0f, actual->Get<float>({1, 1}));
std::unique_ptr<Literal> round_tripped = RoundTripToServer(*actual);
- LiteralTestUtil::ExpectEqual(*round_tripped, *actual);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*round_tripped, *actual));
}
TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) {
EXPECT_EQ(46.0f, actual->Get<float>({1, 1}));
std::unique_ptr<Literal> round_tripped = RoundTripToServer(*actual);
- LiteralTestUtil::ExpectEqual(*round_tripped, *actual);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*round_tripped, *actual));
}
} // namespace
client_->TransferToServer(original).ConsumeValueOrDie();
std::unique_ptr<Literal> result =
client_->Transfer(*data).ConsumeValueOrDie();
- LiteralTestUtil::ExpectEqual(original, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(original, *result));
}
};
&execution_options_)
.ConsumeValueOrDie();
auto expected_literal = Literal::CreateR0<uint32>(dividend / divisor);
- LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
}
}
&execution_options_)
.ConsumeValueOrDie();
auto expected_literal = Literal::CreateR0<uint32>(dividend % divisor);
- LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
}
}
transfer_manager_->TransferLiteralFromDevice(
stream_executor_, device_buffer));
- LiteralTestUtil::ExpectEqual(*literal, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
}
XLA_TEST_F(TransferManagerTest, TransferEmptyTuple) {
transfer_manager_->TransferLiteralFromDevice(
stream_executor_, device_buffer));
- LiteralTestUtil::ExpectEqual(*literal, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
}
XLA_TEST_F(TransferManagerTest, TransferNestedTuple) {
transfer_manager_->TransferLiteralFromDevice(
stream_executor_, device_buffer));
- LiteralTestUtil::ExpectEqual(*literal, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
}
XLA_TEST_F(TransferManagerTest, TransferComplexValue) {
transfer_manager_->TransferLiteralFromDevice(
stream_executor_, device_buffer));
- LiteralTestUtil::ExpectEqual(*literal, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
}
XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) {
transfer_manager_->TransferLiteralFromDevice(
stream_executor_, device_buffer));
- LiteralTestUtil::ExpectEqual(*literal, *result);
+ EXPECT_TRUE(LiteralTestUtil::Equal(*literal, *result));
}
} // namespace