/* static */ StatusOr<std::unique_ptr<HloComputation>>
HloComputation::CreateFromProto(
- HloModule* module, const HloComputationProto& proto,
+ const HloComputationProto& proto,
const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map) {
- std::vector<std::unique_ptr<HloInstruction>> instructions;
tensorflow::gtl::FlatMap<int64, HloInstruction*> instruction_map;
+ tensorflow::gtl::FlatMap<HloInstruction*, int64> to_proto_id;
+ std::vector<std::unique_ptr<HloInstruction>> instructions;
int64 parameter_count = 0;
for (const HloInstructionProto& instruction_proto : proto.instructions()) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloInstruction> instruction,
- HloInstruction::CreateFromProto(module, instruction_proto,
- instruction_map, computation_map));
+ HloInstruction::CreateFromProto(instruction_proto, instruction_map,
+ computation_map));
if (instruction->opcode() == HloOpcode::kParameter) {
parameter_count++;
}
TF_RET_CHECK(!ContainsKey(instruction_map, instruction_proto.id()));
instruction_map[instruction_proto.id()] = instruction.get();
+ to_proto_id[instruction.get()] = instruction_proto.id();
instructions.push_back(std::move(instruction));
}
TF_RET_CHECK(proto.root_id() != -1);
TF_RET_CHECK(ContainsKey(instruction_map, proto.root_id()));
HloInstruction* root = instruction_map.at(proto.root_id());
+
+ // Sort the instructions in proto id order.
+ std::sort(instructions.begin(), instructions.end(),
+ [&](const std::unique_ptr<HloInstruction>& a,
+ const std::unique_ptr<HloInstruction>& b) {
+ return to_proto_id[a.get()] < to_proto_id[b.get()];
+ });
+
return WrapUnique(new HloComputation(proto.name(), parameter_count,
&instructions, root,
/*fusion_instruction=*/nullptr));
// Creates a computation from the given proto. Arguments:
//
- // module: the module which will contain the computation. The newly created
- // computation is *not* added to the module, however.
// proto: the proto to convert from.
// computation_map: a map from computation id to HloComputation*. This map
// must contain all computations which the newly constructed computation
// calls.
static StatusOr<std::unique_ptr<HloComputation>> CreateFromProto(
- HloModule* module, const HloComputationProto& proto,
+ const HloComputationProto& proto,
const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map);
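// Illustrative sketch of calling the module-free signature, assuming callee
// computations are deserialized before their callers so they can be resolved
// through `computation_map`; the proto names `callee_proto` and `caller_proto`
// below are hypothetical.
//
//   tensorflow::gtl::FlatMap<int64, HloComputation*> computation_map;
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<HloComputation> callee,
//       HloComputation::CreateFromProto(callee_proto, computation_map));
//   computation_map[callee_proto.id()] = callee.get();
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<HloComputation> caller,
//       HloComputation::CreateFromProto(caller_proto, computation_map));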
// Gets the instructions in this computation.
/* static */
StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
- HloModule* module, const HloInstructionProto& proto,
+ const HloInstructionProto& proto,
const tensorflow::gtl::FlatMap<int64, HloInstruction*>& instruction_map,
const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map) {
TF_RET_CHECK(!proto.opcode().empty());
proto.add_fft_length(fft_len);
}
+ if (has_sharding()) {
+ *proto.mutable_sharding() = sharding().ToProto();
+ }
+
proto.set_channel_name(channel_name_);
proto.set_cost_estimate_ns(cost_estimate_ns_);
// Creates an instruction from the given proto. Arguments:
//
- // module: the module which will contain the instruction. The newly created
- // instruction is *not* added to the module or any computation, however.
// proto: the proto to convert from.
// instruction_map: a map from instruction id to HloInstruction*. This map
// must contain all operands of the newly constructed instruction.
// computation_map: a map from computation id to HloComputation*. This map
// must contain all computations which the newly constructed instruction
// calls.
static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto(
- HloModule* module, const HloInstructionProto& proto,
+ const HloInstructionProto& proto,
const tensorflow::gtl::FlatMap<int64, HloInstruction*>& instruction_map,
const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map);
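// Illustrative sketch, assuming operands are deserialized before their users
// (mirroring HloComputation::CreateFromProto above); `param_proto` and
// `add_proto` are hypothetical HloInstructionProtos, with `add_proto` listing
// `param_proto.id()` among its operand ids.
//
//   tensorflow::gtl::FlatMap<int64, HloInstruction*> instruction_map;
//   tensorflow::gtl::FlatMap<int64, HloComputation*> computation_map;
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<HloInstruction> param,
//       HloInstruction::CreateFromProto(param_proto, instruction_map,
//                                       computation_map));
//   instruction_map[param_proto.id()] = param.get();
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<HloInstruction> add,
//       HloInstruction::CreateFromProto(add_proto, instruction_map,
//                                       computation_map));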
<< ShapeUtil::HumanStringWithLayout(expected_program_shape.result())
<< ", actual: " << ShapeUtil::HumanStringWithLayout(result_shape);
- auto module = MakeUnique<HloModule>(proto.name(), entry_computation_handle,
- module_config);
-
tensorflow::gtl::FlatMap<int64, HloComputation*> computation_map;
+ tensorflow::gtl::FlatMap<HloComputation*, int64> to_proto_id;
+ std::vector<std::unique_ptr<HloComputation>> computations;
+ HloComputation* entry = nullptr;
for (const HloComputationProto& computation_proto : proto.computations()) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<HloComputation> computation,
- HloComputation::CreateFromProto(
- module.get(), computation_proto, computation_map));
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<HloComputation> computation,
+ HloComputation::CreateFromProto(computation_proto, computation_map));
CHECK_NE(computation.get(), nullptr);
int64 computation_id = computation_proto.id();
TF_RET_CHECK(computation_id != -1);
TF_RET_CHECK(!ContainsKey(computation_map, computation_id));
+ computation_map[computation_id] = computation.get();
+ to_proto_id[computation.get()] = computation_id;
+ if (computation_id == proto.entry_computation_id()) {
+ entry = computation.get();
+ }
+ computations.push_back(std::move(computation));
+ }
+ TF_RET_CHECK(entry != nullptr);
+
+ auto module = MakeUnique<HloModule>(proto.name(), entry_computation_handle,
+ module_config);
+
+ // Sort the computations in proto id order.
+ std::sort(computations.begin(), computations.end(),
+ [&](const std::unique_ptr<HloComputation>& a,
+ const std::unique_ptr<HloComputation>& b) {
+ return to_proto_id[a.get()] < to_proto_id[b.get()];
+ });
+
+ // Add sorted computations to the module.
+ for (auto& computation : computations) {
+ bool is_entry = computation.get() == entry;
// Don't uniquify names because we want names to be stable across
// serialization and deserialization.
- computation_map[computation_id] = module->AddComputationInternal(
- std::move(computation),
- /*is_entry=*/proto.entry_computation_id() == computation_id,
- /*uniquify_names=*/false);
+ module->AddComputationInternal(std::move(computation), is_entry,
+ /*uniquify_names=*/false);
}
TF_RET_CHECK(module->entry_computation_ != nullptr);