Allow generating fake infeed buffers with shapes derived from the computation.
author A. Unique TensorFlower <gardener@tensorflow.org>
Thu, 24 May 2018 09:54:37 +0000 (02:54 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Thu, 24 May 2018 09:57:46 +0000 (02:57 -0700)
When replaying a computation from a HloSnapshot, we want to be able to provide fake
infeed data. This was already possible when the infeed shape is known by providing
it with the --fake_infeed_shape flag. With this CL, we add the option to derive it
from the provided HloSnapshot. Also, we transfer the infeed data a fixed number of
times instead of infinitely many times (configurable with the --num_infeeds flag).
Otherwise we will definitely run out of memory at some point.

PiperOrigin-RevId: 197863412

tensorflow/compiler/xla/tools/replay_computation.cc

index df05013..d641ddf 100644 (file)
@@ -63,6 +63,8 @@ namespace {
 // fields.
 struct Options {
   string fake_infeed_shape;
+  bool generate_fake_infeed = false;
+  int num_infeeds = 10;
   bool use_fake_data = false;
   bool print_result = true;
   int num_runs = 1;
@@ -72,8 +74,12 @@ struct Options {
 // Invokes the given computation passing arbitrary data for every (unbound)
 // parameter if use_fake_data, Otherwise use recorded data if available.
 //
-// Similarly, infeeds fake data of shape fake_infeed_shape if it is provided;
-// otherwise, no infeed is performed.
+// Similarly, infeeds fake data of shape fake_infeed_shape if it is provided.
+// If generate_fake_infeed is true, the required infeed shape is derived from
+// the computation and then used to provide a fake infeed shape.
+//
+// If neither generate_fake_infeed is true nor a fake_infeed_shape is provided,
+// no infeed is performed.
 StatusOr<std::unique_ptr<Literal>> ReplayComputation(const HloSnapshot& module,
                                                      Client* client,
                                                      const Options& opts) {
@@ -92,22 +98,48 @@ StatusOr<std::unique_ptr<Literal>> ReplayComputation(const HloSnapshot& module,
     }
   }
 
+  bool provide_infeed = false;
+  Shape infeed_shape;
+  if (!opts.fake_infeed_shape.empty()) {
+    StatusOr<Shape> shape_status =
+        ShapeUtil::ParseShapeString(opts.fake_infeed_shape);
+    TF_CHECK_OK(shape_status.status());
+    infeed_shape = std::move(shape_status).ValueOrDie();
+    provide_infeed = true;
+  } else if (opts.generate_fake_infeed) {
+    for (const auto& comp : computation.proto().computations()) {
+      for (const auto& instruction : comp.instructions()) {
+        if (instruction.opcode() == HloOpcodeString(HloOpcode::kInfeed)) {
+          CHECK(!provide_infeed)
+              << "--generate_fake_infeed only works if the model has 0 or 1 "
+                 "infeed ops, but this one has >= 2.";
+          provide_infeed = true;
+          infeed_shape = instruction.shape();
+          LOG(INFO) << "Generating fake infeed shape for inferred shape: "
+                    << ShapeUtil::HumanString(infeed_shape);
+        }
+      }
+    }
+  }
   // We only instantiate the thread pool if the user has requested that a
-  // concurrent infeed occur via the fake_infeed_shape.
+  // concurrent infeed occur via the fake_infeed_shape, or when
+  // --generate_fake_infeed is passed and there exists an infeed operation in
+  // the HloSnapshot.
   tensorflow::gtl::optional<tensorflow::thread::ThreadPool> pool;
-
-  if (!opts.fake_infeed_shape.empty()) {
+  if (provide_infeed) {
     pool.emplace(tensorflow::Env::Default(), "infeed",
                  /*num_threads=*/1);
-    pool->Schedule([opts, client]() {
-      StatusOr<Shape> shape_status =
-          ShapeUtil::ParseShapeString(opts.fake_infeed_shape);
-      TF_CHECK_OK(shape_status.status());
-      Shape shape = std::move(shape_status).ValueOrDie();
-      StatusOr<std::unique_ptr<Literal>> data_status = MakeFakeLiteral(shape);
+    pool->Schedule([opts, infeed_shape, client]() {
+      StatusOr<std::unique_ptr<Literal>> data_status =
+          MakeFakeLiteral(infeed_shape);
       TF_CHECK_OK(data_status.status());
       std::unique_ptr<Literal> data = std::move(data_status).ValueOrDie();
-      while (true) {
+      // There may be several infeed buffers needed, however we don't know how
+      // many. If we proactively transfer too many infeed buffers, we may run
+      // out of memory. If we transfer too few infeed buffers, the program will
+      // hang.
+      // TODO(akuegel): Figure out a better way to handle this.
+      for (int i = 0; i < opts.num_infeeds; ++i) {
         TF_CHECK_OK(client->TransferToInfeed(*data));
       }
     });
@@ -202,8 +234,13 @@ int main(int argc, char** argv) {
                        "Print the result of the computation to stdout"),
       tensorflow::Flag("num_runs", &opts.num_runs,
                        "Number of times to run each computation"),
+      tensorflow::Flag("num_infeeds", &opts.num_infeeds,
+                       "Number of times we transfer the fake infeed data"),
       tensorflow::Flag("fake_infeed_shape", &opts.fake_infeed_shape,
                        "Shape of fake data to construct for (infinite) infeed"),
+      tensorflow::Flag("generate_fake_infeed", &opts.generate_fake_infeed,
+                       "Whether a fake infeed shape should be generated "
+                       "derived from the computation"),
       tensorflow::Flag(
           "xla_hlo_profile_last_run", &opts.xla_hlo_profile_last_run,
           "Pass --xla_hlo_profile the last time we run the computation."),