[XLA] Pass the module to HloDataflowAnalysis by const reference.
authorMichael Kuperstein <mkuper@google.com>
Sat, 17 Feb 2018 02:13:53 +0000 (18:13 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sat, 17 Feb 2018 02:17:31 +0000 (18:17 -0800)
PiperOrigin-RevId: 186072673

tensorflow/compiler/xla/service/copy_insertion.cc
tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
tensorflow/compiler/xla/service/hlo_alias_analysis.cc
tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
tensorflow/compiler/xla/service/hlo_ordering_test.cc
tensorflow/compiler/xla/service/liveness_util_test.cc
tensorflow/compiler/xla/tests/test_utils.cc

index c812df42355fb93796df966d8cab4c8437cdfe2e..cc195879a6bb490a9b49ad962aa9326cb51d9b0a 100644 (file)
@@ -1156,7 +1156,7 @@ bool IsWhileBody(const HloComputation* computation,
     HloModule* module) {
   std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDataflowAnalysis> dataflow,
-                      HloDataflowAnalysis::Run(module));
+                      HloDataflowAnalysis::Run(*module));
 
   bool changed = false;
 
index 916b556fd43a453a4da2c96217e74c367f8c7653..9db85bc788bde46c890a46ce9b0902ddce3f5675 100644 (file)
@@ -49,7 +49,7 @@ StatusOr<bool> GpuCopyInsertion::Run(HloModule* module) {
   TF_ASSIGN_OR_RETURN(bool changed, generic_copy_insertion.Run(module));
 
   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDataflowAnalysis> dataflow,
-                      HloDataflowAnalysis::Run(module));
+                      HloDataflowAnalysis::Run(*module));
 
   // Make sure all operands of a library call are in memory instead of constants
   // in IR.
index 6d2a3aa5b531650a658502531e050702ffbd3760..30e32a46d7dd0923f738939c33407ac7484b5bbe 100644 (file)
@@ -419,7 +419,7 @@ StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
   auto alias_analysis = WrapUnique(new HloAliasAnalysis(module));
   TF_ASSIGN_OR_RETURN(
       alias_analysis->dataflow_analysis_,
-      HloDataflowAnalysis::Run(module, /*ssa_form=*/true,
+      HloDataflowAnalysis::Run(*module, /*ssa_form=*/true,
                                /*bitcast_defines_value=*/false));
 
   BufferValueMap buffer_map(alias_analysis->dataflow_analysis());
index ccbbe8f1966d59b4ab2904dcc6ea724aaf4a7603..934e43ba4879628362009267c671ec4cb0d79c52 100644 (file)
@@ -38,12 +38,12 @@ namespace xla {
 using ::tensorflow::strings::StrAppend;
 using ::tensorflow::strings::StrCat;
 
-HloDataflowAnalysis::HloDataflowAnalysis(HloModule* module, bool ssa_form,
+HloDataflowAnalysis::HloDataflowAnalysis(const HloModule& module, bool ssa_form,
                                          bool bitcast_defines_value)
     : module_(module),
       ssa_form_(ssa_form),
       bitcast_defines_value_(bitcast_defines_value),
-      call_graph_(CallGraph::Build(module)) {}
+      call_graph_(CallGraph::Build(&module)) {}
 
 bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction,
                                            const ShapeIndex& index) const {
@@ -115,9 +115,9 @@ void HloDataflowAnalysis::DeleteMarkedValues() {
 }
 
 string HloDataflowAnalysis::ToString() const {
-  string out = StrCat("HloDataflowAnalysis, module ", module_->name(), "\n");
+  string out = StrCat("HloDataflowAnalysis, module ", module_.name(), "\n");
   StrAppend(&out, "  Instruction value sets:\n");
-  for (const HloComputation* computation : module_->computations()) {
+  for (const HloComputation* computation : module_.computations()) {
     for (const HloInstruction* instruction : computation->instructions()) {
       StrAppend(&out, "    ", instruction->name(), ":\n");
       if (ShapeUtil::IsTuple(instruction->shape())) {
@@ -592,7 +592,7 @@ void HloDataflowAnalysis::Propagate() {
     }
   };
 
-  for (HloComputation* computation : module_->computations()) {
+  for (HloComputation* computation : module_.computations()) {
     for (HloInstruction* instruction : computation->instructions()) {
       add_to_worklist(instruction);
     }
@@ -686,7 +686,7 @@ InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet(
 }
 
 Status HloDataflowAnalysis::InitializeInstructionValueSets() {
-  for (const HloComputation* computation : module_->computations()) {
+  for (const HloComputation* computation : module_.computations()) {
     const CallGraphNode& call_graph_node = call_graph_->GetNode(computation);
     for (HloInstruction* instruction : computation->instructions()) {
       // Create an empty shape tree.
@@ -787,9 +787,9 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
 
 /* static */
 StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
-    HloModule* module, bool ssa_form, bool bitcast_defines_value) {
-  VLOG(1) << "HloDataflowAnalysis::Run on module " << module->name();
-  XLA_VLOG_LINES(2, module->ToString());
+    const HloModule& module, bool ssa_form, bool bitcast_defines_value) {
+  VLOG(1) << "HloDataflowAnalysis::Run on module " << module.name();
+  XLA_VLOG_LINES(2, module.ToString());
 
   auto dataflow_analysis = WrapUnique(
       new HloDataflowAnalysis(module, ssa_form, bitcast_defines_value));
@@ -806,7 +806,7 @@ StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
   // lookup is faster.
   std::vector<std::vector<HloPosition>> value_positions(
       dataflow_analysis->next_value_id_);
-  for (const HloComputation* computation : module->computations()) {
+  for (const HloComputation* computation : module.computations()) {
     for (HloInstruction* instruction : computation->instructions()) {
       for (const auto& pair :
            dataflow_analysis->GetInstructionValueSet(instruction)) {
@@ -858,7 +858,7 @@ Status HloDataflowAnalysis::Verify() const {
 
   // For each value in each value set, verify that the value set's position
   // appears in the value's positions().
-  for (const auto& computation : module_->computations()) {
+  for (const auto& computation : module_.computations()) {
     for (const auto& instruction : computation->instructions()) {
       for (const auto& pair : GetInstructionValueSet(instruction)) {
         const ShapeIndex& index = pair.first;
index 89d318188f0855c7924836a51cfe98d531e08cb4..7b8a74b096ff48733717e78ada5bb56a28caed72 100644 (file)
@@ -60,7 +60,7 @@ class HloDataflowAnalysis {
   //     a new HLO value in the analysis. If false then Bitcast forwards the
   //     value of its operand.
   static StatusOr<std::unique_ptr<HloDataflowAnalysis>> Run(
-      HloModule* module, bool ssa_form = false,
+      const HloModule& module, bool ssa_form = false,
       bool bitcast_defines_value = false);
 
   // Returns true if 'instruction' defines an HLO value at the given shape index
@@ -119,7 +119,7 @@ class HloDataflowAnalysis {
   string ToString() const;
 
  protected:
-  HloDataflowAnalysis(HloModule* module, bool ssa_form,
+  HloDataflowAnalysis(const HloModule& module, bool ssa_form,
                       bool bitcast_defines_value = false);
 
   // Returns a new HloValue defined at the given instruction and shape index.
@@ -180,7 +180,7 @@ class HloDataflowAnalysis {
   // Verify various invariants of the dataflow analysis.
   Status Verify() const;
 
-  HloModule* const module_;
+  const HloModule& module_;
   const bool ssa_form_;
   const bool bitcast_defines_value_;
 
index e714b2567fd1b3eab607a19f0bb7e3288150dc64..7bf3a1a06045c79621d75b653bf42220705a69d4 100644 (file)
@@ -50,7 +50,7 @@ class HloDataflowAnalysisTest : public HloTestBase,
                                          bool bitcast_defines_value = false) {
     hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before dataflow analysis");
     analysis_ =
-        HloDataflowAnalysis::Run(module_.get(), ssa_form, bitcast_defines_value)
+        HloDataflowAnalysis::Run(*module_, ssa_form, bitcast_defines_value)
             .ConsumeValueOrDie();
     return *analysis_;
   }
index aba66114de649ce7667ae77174e9c4073b010b90..a989fce63234cb860d08c48b02462e96bec879bc 100644 (file)
@@ -262,8 +262,8 @@ TEST_F(HloOrderingTest, ValuesInWhileComputations) {
       scalar_shape, HloOpcode::kAdd, constant, xla_while));
   module->AddEntryComputation(builder.Build());
 
-  TF_ASSERT_OK_AND_ASSIGN(
-      auto dataflow, HloDataflowAnalysis::Run(module.get(), /*ssa_form=*/true));
+  TF_ASSERT_OK_AND_ASSIGN(auto dataflow,
+                          HloDataflowAnalysis::Run(*module, /*ssa_form=*/true));
   DependencyHloOrdering ordering(module.get());
 
   // Init value is defined before the while, but live range is not before the
index 2c2a02f6375343d67dfb155bbb03729ff6e490d2..f8b309488eeb5391b1cad5db760934ec1f7e3521 100644 (file)
@@ -35,8 +35,7 @@ class PointsToAnalysisTestBase : public HloTestBase {
     CHECK_NOTNULL(module_.get());
     points_to_analysis_ =
         TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
-    dataflow_analysis_ =
-        HloDataflowAnalysis::Run(module_.get()).ConsumeValueOrDie();
+    dataflow_analysis_ = HloDataflowAnalysis::Run(*module_).ConsumeValueOrDie();
   }
 
   void BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation) {
index b060fb13b1451aab30cfca73bea0a4a598a9fa3a..0bc7df2a65b44a76f877b6513e6bf93b99fbc1a3 100644 (file)
@@ -287,7 +287,7 @@ StatusOr<std::unique_ptr<Literal>> MakeFakeLiteral(const Shape& shape) {
 
 StatusOr<std::vector<std::unique_ptr<Literal>>> MakeFakeArguments(
     HloModule* const module) {
-  TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(module));
+  TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module));
   const auto params = module->entry_computation()->parameter_instructions();
   std::minstd_rand0 engine;
   std::vector<std::unique_ptr<Literal>> arguments(params.size());