using ::tensorflow::strings::StrAppend;
using ::tensorflow::strings::StrCat;
-HloDataflowAnalysis::HloDataflowAnalysis(HloModule* module, bool ssa_form,
+HloDataflowAnalysis::HloDataflowAnalysis(const HloModule& module, bool ssa_form,
bool bitcast_defines_value)
: module_(module),
ssa_form_(ssa_form),
bitcast_defines_value_(bitcast_defines_value),
- call_graph_(CallGraph::Build(module)) {}
+ call_graph_(CallGraph::Build(&module)) {}
bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction,
const ShapeIndex& index) const {
}
string HloDataflowAnalysis::ToString() const {
- string out = StrCat("HloDataflowAnalysis, module ", module_->name(), "\n");
+ string out = StrCat("HloDataflowAnalysis, module ", module_.name(), "\n");
StrAppend(&out, " Instruction value sets:\n");
- for (const HloComputation* computation : module_->computations()) {
+ for (const HloComputation* computation : module_.computations()) {
for (const HloInstruction* instruction : computation->instructions()) {
StrAppend(&out, " ", instruction->name(), ":\n");
if (ShapeUtil::IsTuple(instruction->shape())) {
}
};
- for (HloComputation* computation : module_->computations()) {
+ for (HloComputation* computation : module_.computations()) {
for (HloInstruction* instruction : computation->instructions()) {
add_to_worklist(instruction);
}
}
Status HloDataflowAnalysis::InitializeInstructionValueSets() {
- for (const HloComputation* computation : module_->computations()) {
+ for (const HloComputation* computation : module_.computations()) {
const CallGraphNode& call_graph_node = call_graph_->GetNode(computation);
for (HloInstruction* instruction : computation->instructions()) {
// Create an empty shape tree.
/* static */
StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
- HloModule* module, bool ssa_form, bool bitcast_defines_value) {
- VLOG(1) << "HloDataflowAnalysis::Run on module " << module->name();
- XLA_VLOG_LINES(2, module->ToString());
+ const HloModule& module, bool ssa_form, bool bitcast_defines_value) {
+ VLOG(1) << "HloDataflowAnalysis::Run on module " << module.name();
+ XLA_VLOG_LINES(2, module.ToString());
auto dataflow_analysis = WrapUnique(
new HloDataflowAnalysis(module, ssa_form, bitcast_defines_value));
// lookup is faster.
std::vector<std::vector<HloPosition>> value_positions(
dataflow_analysis->next_value_id_);
- for (const HloComputation* computation : module->computations()) {
+ for (const HloComputation* computation : module.computations()) {
for (HloInstruction* instruction : computation->instructions()) {
for (const auto& pair :
dataflow_analysis->GetInstructionValueSet(instruction)) {
// For each value in each value set, verify that the value set's position
// appears in the value's positions().
- for (const auto& computation : module_->computations()) {
+ for (const auto& computation : module_.computations()) {
for (const auto& instruction : computation->instructions()) {
for (const auto& pair : GetInstructionValueSet(instruction)) {
const ShapeIndex& index = pair.first;
// a new HLO value in the analysis. If false then Bitcast forwards the
// value of its operand.
static StatusOr<std::unique_ptr<HloDataflowAnalysis>> Run(
- HloModule* module, bool ssa_form = false,
+ const HloModule& module, bool ssa_form = false,
bool bitcast_defines_value = false);
// Returns true if 'instruction' defines an HLO value at the given shape index
string ToString() const;
protected:
- HloDataflowAnalysis(HloModule* module, bool ssa_form,
+ HloDataflowAnalysis(const HloModule& module, bool ssa_form,
bool bitcast_defines_value = false);
// Returns a new HloValue defined at the given instruction and shape index.
// Verify various invariants of the dataflow analysis.
Status Verify() const;
- HloModule* const module_;
+ const HloModule& module_;
const bool ssa_form_;
const bool bitcast_defines_value_;