return result;
}
+StatusOr<std::unique_ptr<Literal>> LiteralBase::Broadcast(
+    const Shape& result_shape,
+    tensorflow::gtl::ArraySlice<int64> dimensions) const {
+  // Broadcasts this literal into `result_shape`; `dimensions[i]` names the
+  // dimension of `result_shape` that dimension i of this literal maps onto.
+  if (!ShapeUtil::IsArray(shape())) {
+    return InvalidArgument("Broadcast only supports arrays.");
+  }
+
+  // There must be exactly one mapping entry per source dimension; otherwise
+  // the loop below would index past `shape().dimensions()` (too many) or
+  // silently pin the unmapped source dimensions at index 0 (too few).
+  TF_RET_CHECK(dimensions.size() == ShapeUtil::Rank(shape()));
+  for (int64 i = 0; i < dimensions.size(); i++) {
+    TF_RET_CHECK(shape().dimensions(i) ==
+                 result_shape.dimensions(dimensions[i]));
+  }
+
+  std::unique_ptr<Literal> result = MakeUnique<Literal>(result_shape);
+
+  // scratch_source_index is temporary storage space for the computed index into
+  // the input literal. We put it here to avoid allocating an std::vector in
+  // every iteration of ShapeUtil::ForEachIndex.
+  std::vector<int64> scratch_source_index(shape().dimensions_size());
+
+  // Copy element-by-element as raw bytes so a single implementation covers
+  // every primitive type.
+  char* dest_data = static_cast<char*>(result->untyped_data());
+  const char* source_data = static_cast<const char*>(untyped_data());
+  const int64 primitive_size =
+      ShapeUtil::ByteSizeOfPrimitiveType(shape().element_type());
+
+  ShapeUtil::ForEachIndex(
+      result_shape, [&](tensorflow::gtl::ArraySlice<int64> output_index) {
+        // Project the output index back onto the source dimensions.
+        for (int64 i = 0; i < dimensions.size(); ++i) {
+          scratch_source_index[i] = output_index[dimensions[i]];
+        }
+        int64 dest_index = IndexUtil::MultidimensionalIndexToLinearIndex(
+            result_shape, output_index);
+        int64 source_index = IndexUtil::MultidimensionalIndexToLinearIndex(
+            shape(), scratch_source_index);
+        memcpy(dest_data + primitive_size * dest_index,
+               source_data + primitive_size * source_index, primitive_size);
+        return true;
+      });
+
+  return std::move(result);
+}
+
StatusOr<std::unique_ptr<Literal>> LiteralBase::Reshape(
tensorflow::gtl::ArraySlice<int64> dimensions) const {
if (!ShapeUtil::IsArray(shape())) {
StatusOr<std::unique_ptr<Literal>> Reshape(
tensorflow::gtl::ArraySlice<int64> dimensions) const;
+  // Creates a new literal of shape `result_shape` by broadcasting this
+  // literal into it. `dimensions[i]` gives the dimension of `result_shape`
+  // that dimension i of this literal maps onto; each mapped result dimension
+  // must have the same extent as the corresponding source dimension.
+  // Returns an InvalidArgument error if this literal is not an array.
+  StatusOr<std::unique_ptr<Literal>> Broadcast(
+      const Shape& result_shape,
+      tensorflow::gtl::ArraySlice<int64> dimensions) const;
+
// Creates a new literal by reordering the dimensions of this literal.
// The given `permutation` must be a permutation of the dimension numbers
// in the original literal, and it specifies the order of the new dimensions
tensorflow::strings::StrCat("(", float{3.0}, ", ", float{4.0}, ")"));
}
+TEST_F(LiteralUtilTest, BroadcastVectorToMatrix0) {
+  // Broadcasting {1, 2} along result dimension 0 replicates each element
+  // across the columns of the 2x2 result.
+  std::unique_ptr<Literal> input = Literal::CreateR1<int64>({1, 2});
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<Literal> result,
+      input->Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}),
+                       /*dimensions=*/{0}));
+  EXPECT_EQ(*result, *Literal::CreateR2<int64>({{1, 1}, {2, 2}}));
+}
+
+TEST_F(LiteralUtilTest, BroadcastVectorToMatrix1) {
+  // Broadcasting {1, 2} along result dimension 1 replicates the whole
+  // vector across the rows of the 2x2 result.
+  std::unique_ptr<Literal> input = Literal::CreateR1<int64>({1, 2});
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<Literal> result,
+      input->Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S64, {2, 2}),
+                       /*dimensions=*/{1}));
+  EXPECT_EQ(*result, *Literal::CreateR2<int64>({{1, 2}, {1, 2}}));
+}
+
+TEST_F(LiteralUtilTest, BroadcastScalarToMatrix) {
+  // A rank-0 literal uses an empty dimension mapping and fills every
+  // element of the result.
+  std::unique_ptr<Literal> input = Literal::CreateR0<int32>(9);
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<Literal> result,
+      input->Broadcast(/*result_shape=*/ShapeUtil::MakeShape(S32, {2, 2}),
+                       /*dimensions=*/{}));
+  EXPECT_EQ(*result, *Literal::CreateR2<int32>({{9, 9}, {9, 9}}));
+}
+
} // namespace
} // namespace xla
return Status::OK();
}
+Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) {
+  // Validates the broadcast's dimension mapping against the operand's shape,
+  // then delegates the actual expansion to Literal::Broadcast.
+  const Literal& operand_literal =
+      GetEvaluatedLiteralFor(broadcast->operand(0));
+  const int64 operand_rank = ShapeUtil::Rank(operand_literal.shape());
+
+  TF_RET_CHECK(broadcast->dimensions().size() == operand_rank)
+      << "broadcast dimensions is of size: " << broadcast->dimensions().size()
+      << " and rank of operand_to_broadcast is: " << operand_rank;
+  // Checks that operand's dimensions are the same as the broadcast's
+  // dimensions along the dimensions to be broadcasted.
+  for (int64 i = 0; i < operand_rank; ++i) {
+    const int64 result_dim = broadcast->dimensions(i);
+    TF_RET_CHECK(broadcast->shape().dimensions(result_dim) ==
+                 operand_literal.shape().dimensions(i));
+  }
+
+  TF_ASSIGN_OR_RETURN(
+      evaluated_[broadcast],
+      operand_literal.Broadcast(broadcast->shape(), broadcast->dimensions()));
+
+  return Status::OK();
+}
+
Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) {
const auto result_shape = get_tuple_element->shape();
const int64 index = get_tuple_element->tuple_index();
Status HandleSelect(HloInstruction* select) override;
+ Status HandleBroadcast(HloInstruction* broadcast) override;
+
// Returns the already-evaluated literal result for the instruction.
// A Constant instruction is considered evaluated and its literal will be
// returned directly without looking up the cache.
return HandleRound<ReturnT>(round);
}
- Status HandleBroadcast(HloInstruction* broadcast) override {
- const Literal& operand_to_broadcast =
- parent_->GetEvaluatedLiteralFor(broadcast->operand(0));
- std::vector<int64> broadcast_indices(
- ShapeUtil::Rank(broadcast->operand(0)->shape()), 0);
-
- TF_RET_CHECK(broadcast->dimensions().size() ==
- ShapeUtil::Rank(operand_to_broadcast.shape()))
- << "broadcast dimensions is of size: " << broadcast->dimensions().size()
- << " and rank of operand_to_broadcast is: "
- << ShapeUtil::Rank(operand_to_broadcast.shape());
- // Checks that operand's dimensions are the same as the broadcast's
- // dimensions along the dimensions to be broadcasted.
- for (int64 i = 0; i < broadcast->dimensions().size(); ++i) {
- TF_RET_CHECK(broadcast->shape().dimensions(broadcast->dimensions(i)) ==
- operand_to_broadcast.shape().dimensions(i));
- }
-
- auto output = MakeUnique<Literal>(broadcast->shape());
- TF_RETURN_IF_ERROR(output->Populate<ReturnT>(
- [&](tensorflow::gtl::ArraySlice<int64> multi_index) {
- for (int64 i = 0; i < broadcast->dimensions().size(); ++i) {
- broadcast_indices[i] = multi_index[broadcast->dimensions(i)];
- }
- return operand_to_broadcast.Get<ReturnT>(broadcast_indices);
- }));
- parent_->evaluated_[broadcast] = std::move(output);
- return Status::OK();
- }
-
template <
typename NativeT,
typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
.IgnoreError();
}
+  // These convenience wrappers don't take `base`, `count` and `incr`
+  // explicitly, but iterate over every element in `shape` instead.
+
+  template <typename FnType>
+  static Status ForEachIndexWithStatus(const Shape& shape,
+                                       const FnType& visitor_function) {
+    // Start at the origin and step by one in every dimension, covering the
+    // full extent of `shape`.
+    std::vector<int64> zero_base(shape.dimensions_size(), 0);
+    std::vector<int64> unit_incr(shape.dimensions_size(), 1);
+    return ForEachIndexWithStatus(shape, zero_base,
+                                  /*count=*/AsInt64Slice(shape.dimensions()),
+                                  unit_incr, visitor_function);
+  }
+
+  template <typename FnType>
+  static void ForEachIndex(const Shape& shape, const FnType& visitor_function) {
+    // Adapt the bool-returning visitor to the StatusOr<bool> form expected by
+    // ForEachIndexWithStatus, then drop the resulting status.
+    auto adapted_visitor = [&](tensorflow::gtl::ArraySlice<int64> indices) {
+      return StatusOr<bool>(visitor_function(indices));
+    };
+    ForEachIndexWithStatus(shape, adapted_visitor).IgnoreError();
+  }
+
+
// A parallel version of ForEachIndex(WithStatus). This can only be used if
// the visitor_function is thread-safe and the order of iteration does not
// matter.