// fields.
struct Options {
string fake_infeed_shape;
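+ // If true (and no fake_infeed_shape is given), infer the infeed shape from
+ // the computation's infeed instruction.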
+ bool generate_fake_infeed = false;
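+ // Number of times the fake infeed literal is transferred to the infeed.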
+ int num_infeeds = 10;
bool use_fake_data = false;
bool print_result = true;
int num_runs = 1;
// Invokes the given computation passing arbitrary data for every (unbound)
// parameter if use_fake_data; otherwise, use recorded data if available.
//
-// Similarly, infeeds fake data of shape fake_infeed_shape if it is provided;
-// otherwise, no infeed is performed.
+// Similarly, infeeds fake data of shape fake_infeed_shape if it is provided.
+// If generate_fake_infeed is true, the required infeed shape is derived from
+// the computation and fake data of that shape is fed in.
+//
+// If neither generate_fake_infeed is true nor a fake_infeed_shape is provided,
+// no infeed is performed.
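+//
+// For example, passing --generate_fake_infeed together with --num_infeeds
+// replays a computation without having to spell out its infeed shape via
+// --fake_infeed_shape.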
StatusOr<std::unique_ptr<Literal>> ReplayComputation(const HloSnapshot& module,
Client* client,
const Options& opts) {
}
}
+ bool provide_infeed = false;
+ Shape infeed_shape;
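+ // An explicitly passed --fake_infeed_shape takes precedence; otherwise the
+ // shape may be inferred from the computation below.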
+ if (!opts.fake_infeed_shape.empty()) {
+ StatusOr<Shape> shape_status =
+ ShapeUtil::ParseShapeString(opts.fake_infeed_shape);
+ TF_CHECK_OK(shape_status.status());
+ infeed_shape = std::move(shape_status).ValueOrDie();
+ provide_infeed = true;
+ } else if (opts.generate_fake_infeed) {
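+ // Walk every computation in the snapshot's HLO proto looking for infeed
+ // instructions; the shape recorded on the (single) infeed op becomes the
+ // fake infeed shape.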
+ for (const auto& comp : computation.proto().computations()) {
+ for (const auto& instruction : comp.instructions()) {
+ if (instruction.opcode() == HloOpcodeString(HloOpcode::kInfeed)) {
+ CHECK(!provide_infeed)
+ << "--generate_fake_infeed only works if the model has 0 or 1 "
+ "infeed ops, but this one has >= 2.";
+ provide_infeed = true;
+ infeed_shape = instruction.shape();
+ LOG(INFO) << "Generating fake infeed data for inferred shape: "
+ << ShapeUtil::HumanString(infeed_shape);
+ }
+ }
+ }
+ }
// We only instantiate the thread pool if the user has requested that a
- // concurrent infeed occur via the fake_infeed_shape.
+ // concurrent infeed occur via the fake_infeed_shape, or when
+ // --generate_fake_infeed is passed and there exists an infeed operation in
+ // the HloSnapshot.
tensorflow::gtl::optional<tensorflow::thread::ThreadPool> pool;
-
- if (!opts.fake_infeed_shape.empty()) {
+ if (provide_infeed) {
pool.emplace(tensorflow::Env::Default(), "infeed",
/*num_threads=*/1);
- pool->Schedule([opts, client]() {
- StatusOr<Shape> shape_status =
- ShapeUtil::ParseShapeString(opts.fake_infeed_shape);
- TF_CHECK_OK(shape_status.status());
- Shape shape = std::move(shape_status).ValueOrDie();
- StatusOr<std::unique_ptr<Literal>> data_status = MakeFakeLiteral(shape);
+ pool->Schedule([opts, infeed_shape, client]() {
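+ // Build one fake literal of the infeed shape and reuse it for every
+ // transfer.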
+ StatusOr<std::unique_ptr<Literal>> data_status =
+ MakeFakeLiteral(infeed_shape);
TF_CHECK_OK(data_status.status());
std::unique_ptr<Literal> data = std::move(data_status).ValueOrDie();
- while (true) {
+ // There may be several infeed buffers needed; however, we don't know how
+ // many. If we proactively transfer too many infeed buffers, we may run
+ // out of memory. If we transfer too few infeed buffers, the program will
+ // hang.
+ // TODO(akuegel): Figure out a better way to handle this.
+ for (int i = 0; i < opts.num_infeeds; ++i) {
TF_CHECK_OK(client->TransferToInfeed(*data));
}
});
"Print the result of the computation to stdout"),
tensorflow::Flag("num_runs", &opts.num_runs,
"Number of times to run each computation"),
+ tensorflow::Flag("num_infeeds", &opts.num_infeeds,
+ "Number of times we transfer the fake infeed data"),
tensorflow::Flag("fake_infeed_shape", &opts.fake_infeed_shape,
"Shape of fake data to construct for (infinite) infeed"),
+ tensorflow::Flag("generate_fake_infeed", &opts.generate_fake_infeed,
+ "Whether a fake infeed shape should be generated "
+ "derived from the computation"),
tensorflow::Flag(
"xla_hlo_profile_last_run", &opts.xla_hlo_profile_last_run,
"Pass --xla_hlo_profile the last time we run the computation."),