namespace xla {
namespace {
+// Calculates the number of bytes required to store the data within the
+// specified shape. In case of a (nested) tuple shape this is the total byte
+// size of all sub-shapes within the tuple.
+int64 DataSizeOfShape(const Shape& shape) {
+ if (ShapeUtil::IsArray(shape)) {
+ return ShapeUtil::ByteSizeOf(shape);
+ }
+
+ int64 total_size = 0;
+ for (const Shape& s : shape.tuple_shapes()) {
+ total_size += DataSizeOfShape(s);
+ }
+ return total_size;
+}
+
+// Create a ComputationDataHandle for an op what generates fake data with the
+// given shape.
+ComputationDataHandle BuildFakeDataOpOnDevice(const Shape& shape,
+ ComputationBuilder* builder) {
+ if (ShapeUtil::IsArray(shape)) {
+ return builder->Broadcast(
+ builder->ConstantLiteral(Literal::One(shape.element_type())),
+ AsInt64Slice(shape.dimensions()));
+ }
+ std::vector<ComputationDataHandle> parts;
+ for (const Shape& s : shape.tuple_shapes()) {
+ parts.push_back(BuildFakeDataOpOnDevice(s, builder));
+ }
+ return builder->Tuple(parts);
+}
+
std::unique_ptr<GlobalData> MakeFakeDataViaDeviceOrDie(const Shape& shape,
Client* client) {
ComputationBuilder b(
client,
tensorflow::strings::StrCat("make_fake_", ShapeUtil::HumanString(shape)));
- // TODO(b/26811613): Replace this when RNG is supported on all backends.
- b.Broadcast(b.ConstantLiteral(Literal::One(shape.element_type())),
- AsInt64Slice(shape.dimensions()));
+ BuildFakeDataOpOnDevice(shape, &b);
Computation computation = b.Build().ConsumeValueOrDie();
auto execution_options = CreateDefaultExecutionOptions();
std::unique_ptr<GlobalData> MakeFakeDataOrDie(const Shape& shape,
Client* client) {
- if (ShapeUtil::ByteSizeOf(shape) < (1LL << 20)) {
+ if (DataSizeOfShape(shape) < (1LL << 20)) {
StatusOr<std::unique_ptr<Literal>> literal_status = MakeFakeLiteral(shape);
if (!literal_status.ok()) {
// If we got an Unimplemented error, fall back to making the fake data via