Add infrastructure for a backend-specific configuration for each op. This is intentio...
authorBjarke Hammersholt Roune <broune@google.com>
Fri, 4 May 2018 23:51:06 +0000 (16:51 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sat, 5 May 2018 15:33:28 +0000 (08:33 -0700)
PiperOrigin-RevId: 195493500

14 files changed:
tensorflow/compiler/xla/service/hlo.proto
tensorflow/compiler/xla/service/hlo_computation.cc
tensorflow/compiler/xla/service/hlo_computation.h
tensorflow/compiler/xla/service/hlo_graph_dumper.cc
tensorflow/compiler/xla/service/hlo_graph_dumper.h
tensorflow/compiler/xla/service/hlo_instruction.cc
tensorflow/compiler/xla/service/hlo_instruction.h
tensorflow/compiler/xla/service/hlo_module.cc
tensorflow/compiler/xla/service/hlo_module.h
tensorflow/compiler/xla/service/hlo_verifier.cc
tensorflow/compiler/xla/statusor.h
tensorflow/compiler/xla/statusor_test.cc
tensorflow/compiler/xla/tools/parser/hlo_parser.cc
tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc

index aa68608..1f7c1cf 100644 (file)
@@ -147,6 +147,9 @@ message HloInstructionProto {
   repeated int64 called_computation_ids = 38;
 
   xla.OpSharding sharding = 40;
+
+  // Backend configuration for the instruction. Has backend-specific meaning.
+  string backend_config = 43;
 }
 
 // Serialization of HloComputation.
index 594413e..17e43c3 100644 (file)
@@ -347,6 +347,11 @@ std::list<HloComputation*> HloComputation::MakeEmbeddedComputationsList()
   // To avoid special handling of this computation, cast away const of
   // 'this'. 'this' is immediately removed from the post order after
   // construction.
+  //
+  // TODO(b/78350259): This violates const-correctness, since while the original
+  // computation is not returned, we still retrieve non-const computations from
+  // a const one. Consider also avoiding const for HloComputation, or review XLA
+  // for const-correctness of non-HloInstruction* types like this.
   ComputeComputationPostOrder(const_cast<HloComputation*>(this), &visited,
                               &post_order);
 
@@ -723,18 +728,25 @@ Status HloComputation::Accept(
   return this->Accept(&visitor);
 }
 
-std::unique_ptr<HloComputation> HloComputation::Clone(const string& suffix,
-                                                      HloModule* module) {
+std::unique_ptr<HloComputation> HloComputation::Clone(
+    const string& suffix, HloModule* module,
+    HloInstruction::CloneMap* clone_map) {
   return CloneWithReplacements(
       /*replacements=*/std::unordered_map<const HloInstruction*,
                                           std::unique_ptr<HloInstruction>>(),
-      module, suffix);
+      module, clone_map, suffix);
 }
 
 std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
     std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
         replacements,
-    HloModule* module, const string& suffix) {
+    HloModule* module, HloInstruction::CloneMap* clone_map,
+    const string& suffix) {
+  HloInstruction::CloneMap local_clone_map;
+  if (clone_map == nullptr) {
+    clone_map = &local_clone_map;
+  }
+
   // Look up instr in the replacements map, and return either the replacement,
   // or instr, if the replacement isn't present.
   //
@@ -756,24 +768,19 @@ std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
     }
   }
 
-  std::unordered_map<HloInstruction*, HloInstruction*> clone_map;
   std::vector<std::unique_ptr<HloInstruction>> instructions;
   std::unique_ptr<HloInstruction> new_instr = nullptr;
   for (auto instr : postorder) {
     std::vector<HloInstruction*> new_operands;
     for (auto operand : instr->operands()) {
       auto replaced_operand = replace(operand);
-      // If replaced_operand is null, that means 'replacements' asked us not to
-      // include operand in the new computation.  But we can't do that, because
-      // operand is used by instr.
       CHECK_NE(replaced_operand, nullptr)
-          << "replacements map tried to eliminate a used instruction "
-          << operand->ToString() << ", used by " << instr->ToString();
-      new_operands.push_back(FindOrDie(clone_map, replaced_operand));
+          << "Replacements map specifies to leave out " << operand->ToString()
+          << ", but it is used by " << instr->ToString() << ".";
+      new_operands.push_back(FindOrDie(*clone_map, replaced_operand));
     }
-    new_instr =
-        instr->CloneWithNewOperands(instr->shape(), new_operands, module);
-    InsertOrDie(&clone_map, instr, new_instr.get());
+    new_instr = instr->CloneWithNewOperands(instr->shape(), new_operands,
+                                            module, clone_map);
     instructions.push_back(std::move(new_instr));
   }
   Builder builder(name() + "." + suffix);
@@ -781,27 +788,24 @@ std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
     builder.AddInstruction(std::move(instr));
   }
   auto result = builder.Build(
-      /*root_instruction=*/FindOrDie(clone_map, replace(root_instruction())));
+      /*root_instruction=*/FindOrDie(*clone_map, replace(root_instruction())));
 
   // Clone control dependencies.
   for (auto instr : postorder) {
-    HloInstruction* new_instr = FindOrDie(clone_map, instr);
+    HloInstruction* new_instr = FindOrDie(*clone_map, instr);
     for (auto successor : instr->control_successors()) {
       auto replaced_successor = replace(successor);
-
-      // successor may not be in clone_map, because it might have been
-      // removed by the replacements map.
-      if (replaced_successor == nullptr) {
-        continue;
-      }
+      CHECK_NE(replaced_successor, nullptr)
+          << "Replacements map specifies to leave out " << successor->ToString()
+          << ", but it is control-depended-on by " << instr->ToString() << ".";
 
       TF_CHECK_OK(new_instr->AddControlDependencyTo(
-          FindOrDie(clone_map, replaced_successor)));
+          FindOrDie(*clone_map, replaced_successor)));
     }
   }
 
   // We cloned the elements of 'replacements', so they're all going to be
-  // destroyed.  HloInstructions need to be detached from their operands before
+  // destroyed. HloInstructions need to be detached from their operands before
   // they're destroyed, otherwise they stick around in the operands' users lists
   // and cause use-after-frees.
   for (auto& kv : replacements) {
index 9d3f6e9..9898355 100644 (file)
@@ -291,11 +291,17 @@ class HloComputation {
       const std::function<Status(const HloInstruction*)>& visitor_func) const;
 
   // Returns a deep copy of this computation including all instructions.
-  // If the module pointer is not nullptr, it will be the module where
-  // the cloned computations will be added to (in order to support deep
-  // cloning).
-  std::unique_ptr<HloComputation> Clone(const string& suffix = "clone",
-                                        HloModule* module = nullptr);
+  //
+  // If the module pointer is not nullptr, then the cloned computations will be
+  // added to this module in order to support deep cloning. Otherwise the module
+  // of the computation is used.
+  //
+  // If clone_map is not nullptr, then each original instruction that is cloned
+  // will be inserted and map to its clone. clone_map should not already contain
+  // any of the instructions to clone.
+  std::unique_ptr<HloComputation> Clone(
+      const string& suffix = "clone", HloModule* module = nullptr,
+      HloInstruction::CloneMap* clone_map = nullptr);
 
   // Like Clone(), but if an instruction is present in replacement_map, we use
   // the map's value to replace that instruction in the cloned computation.
@@ -305,7 +311,9 @@ class HloComputation {
   std::unique_ptr<HloComputation> CloneWithReplacements(
       std::unordered_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
           replacements,
-      HloModule* module = nullptr, const string& suffix = "clone");
+      HloModule* module = nullptr,
+      HloInstruction::CloneMap* clone_map = nullptr,
+      const string& suffix = "clone");
 
   // Returns true if the given instruction can be removed from the computation.
   // Parameter instructions cannot be removed without violating invariants of
index bb4db89..794f1b4 100644 (file)
@@ -322,11 +322,13 @@ class HloDotDumper {
  public:
   HloDotDumper(const HloComputation* computation, tensorflow::StringPiece label,
                const DebugOptions& debug_options, bool show_metadata,
-               const HloExecutionProfile* profile, NodeFilter filter)
+               bool show_backend_config, const HloExecutionProfile* profile,
+               NodeFilter filter)
       : computation_(computation),
         label_(label.ToString()),
         debug_options_(debug_options),
         show_metadata_(show_metadata),
+        show_backend_config_(show_backend_config),
         profile_(profile),
         filter_(std::move(filter)) {}
 
@@ -365,6 +367,7 @@ class HloDotDumper {
   string GetInstructionNodeShape(const HloInstruction* instr);
   string GetInstructionNodeLabel(const HloInstruction* instr);
   string GetInstructionNodeMetadata(const HloInstruction* instr);
+  string GetInstructionNodeBackendConfig(const HloInstruction* instr);
   string GetInstructionNodeExtraInfo(const HloInstruction* instr);
   string GetInstructionNodeInlinedOperands(const HloInstruction* instr);
   void AddInstructionIncomingEdges(const HloInstruction* instr);
@@ -393,6 +396,7 @@ class HloDotDumper {
   const string label_;                 // overall name for the graph
   const DebugOptions& debug_options_;
   const bool show_metadata_;
+  const bool show_backend_config_;
   const HloExecutionProfile* profile_;  // may be null
   const NodeFilter filter_;
 
@@ -611,6 +615,10 @@ tooltip = " ";
     if (!extra_info.empty()) {
       StrAppend(&subcomp_label, "<br/>", extra_info);
     }
+    string node_backend_config = GetInstructionNodeBackendConfig(parent_instr);
+    if (!node_backend_config.empty()) {
+      StrAppend(&subcomp_label, "<br/>", node_backend_config);
+    }
 
     bool highlight = filter_.Highlight(parent_instr);
     const char* fillcolor;
@@ -765,6 +773,7 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
   string node_shape = GetInstructionNodeShape(instr);
   string node_label = GetInstructionNodeLabel(instr);
   string node_metadata = GetInstructionNodeMetadata(instr);
+  string node_backend_config = GetInstructionNodeBackendConfig(instr);
   string extra_info = GetInstructionNodeExtraInfo(instr);
   string inlined_constants = GetInstructionNodeInlinedOperands(instr);
   string trivial_subcomputation = GetInstructionTrivialComputationStr(instr);
@@ -782,8 +791,8 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
   }
   // Build the text that will be displayed inside the node.
   string node_body = node_label;
-  for (const string& s :
-       {trivial_subcomputation, node_metadata, extra_info, inlined_constants}) {
+  for (const string& s : {trivial_subcomputation, node_metadata,
+                          node_backend_config, extra_info, inlined_constants}) {
     if (!s.empty()) {
       StrAppend(&node_body, "<br/>", s);
     }
@@ -1078,6 +1087,15 @@ string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) {
   return Join(lines, "<br/>");
 }
 
+string HloDotDumper::GetInstructionNodeBackendConfig(
+    const HloInstruction* instr) {
+  if (!show_backend_config_ || instr->backend_config().empty()) {
+    return "";
+  }
+
+  return StrCat("backend_config=\"", instr->backend_config(), "\"");
+}
+
 string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
   std::vector<string> lines;
 
@@ -1404,7 +1422,7 @@ string ExportGraph(const string& graph,
 string DumpGraph(const HloComputation& computation, const string& label,
                  const DebugOptions& debug_options,
                  const HloExecutionProfile* hlo_execution_profile,
-                 bool show_metadata) {
+                 bool show_metadata, bool show_backend_config) {
   GraphRendererInterface::GraphKind graph_kind;
   string graph;
   if (debug_options.xla_hlo_dump_as_graphdef()) {
@@ -1414,9 +1432,10 @@ string DumpGraph(const HloComputation& computation, const string& label,
                                                           &graph));
     graph_kind = GraphRendererInterface::TF_GRAPHDEF;
   } else {
-    graph = HloDotDumper(&computation, label, debug_options, show_metadata,
-                         hlo_execution_profile, NodeFilter())
-                .Dump();
+    graph =
+        HloDotDumper(&computation, label, debug_options, show_metadata,
+                     show_backend_config, hlo_execution_profile, NodeFilter())
+            .Dump();
     graph_kind = GraphRendererInterface::DOT_GRAPH;
   }
 
@@ -1427,15 +1446,15 @@ string DumpGraph(const HloComputation& computation, const string& label,
 }
 
 string DumpNeighborhoodAround(const HloInstruction& node, int radius,
-                              bool show_metadata) {
+                              bool show_metadata, bool show_backend_config) {
   auto debug_options = node.GetModule()->config().debug_options();
   string label =
       StrCat("Neighborhood of ", radius, " nodes around ", node.name());
   NodeFilter filter = MakeNodeFilter(&node, radius);
-  string graph =
-      HloDotDumper(node.parent(), label, debug_options, show_metadata,
-                   /*profile=*/nullptr, filter)
-          .Dump();
+  string graph = HloDotDumper(node.parent(), label, debug_options,
+                              show_metadata, show_backend_config,
+                              /*profile=*/nullptr, filter)
+                     .Dump();
   return ExportGraph(graph, GraphRendererInterface::DOT_GRAPH, debug_options);
 }
 
index 2704aae..fc8e146 100644 (file)
@@ -56,7 +56,7 @@ string MaybeDumpHloModule(const HloModule& module, const string& label,
 string DumpGraph(const HloComputation& computation, const string& label,
                  const DebugOptions& debug_options,
                  const HloExecutionProfile* hlo_execution_profile = nullptr,
-                 bool show_metadata = false);
+                 bool show_metadata = false, bool show_backend_config = false);
 
 // Like DumpGraph, but renders only nodes "near" the given node in the graph.
 //
@@ -64,7 +64,8 @@ string DumpGraph(const HloComputation& computation, const string& label,
 // (roughly) corresponds to the max distance a node may be from the primary node
 // before it's omitted from the graph.
 string DumpNeighborhoodAround(const HloInstruction& node, int radius,
-                              bool show_metadata = false);
+                              bool show_metadata = false,
+                              bool show_backend_config = false);
 
 // Dumps the HloModule::ToString() as a file into the provided directory path
 // suffixed with the provided label.
index a714d0e..2c73372 100644 (file)
@@ -109,6 +109,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
   instruction->name_ = proto.name();
 
   instruction->metadata_ = proto.metadata();
+  instruction->set_backend_config(proto.backend_config());
   if (proto.has_literal()) {
     TF_ASSIGN_OR_RETURN(instruction->literal_,
                         Literal::CreateFromProto(proto.literal()));
@@ -1231,12 +1232,15 @@ bool HloInstruction::HasSideEffect() const {
 std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
     const Shape& shape,
     tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
-    HloModule* module) const {
+    HloModule* module, CloneMap* clone_map) const {
   VLOG(3) << "CloneWithNewOperands:\n  " << ToString();
   VLOG(3) << "  new operands:";
   for (const HloInstruction* new_operand : new_operands) {
     VLOG(3) << "    %" << new_operand->name();
   }
+  if (module == nullptr) {
+    module = GetModule();
+  }
 
   std::unique_ptr<HloInstruction> clone;
 
@@ -1342,7 +1346,8 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
       break;
     case HloOpcode::kFft:
       CHECK_EQ(new_operands.size(), 1);
-      return CreateFft(shape, new_operands[0], fft_type_, fft_length_);
+      clone = CreateFft(shape, new_operands[0], fft_type_, fft_length_);
+      break;
     case HloOpcode::kCrossReplicaSum:
       clone = CreateCrossReplicaSum(shape, new_operands);
       break;
@@ -1415,9 +1420,15 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
     case HloOpcode::kConstant:
       clone = CreateConstant(literal_->CloneToUnique());
       break;
-    case HloOpcode::kFusion:
-      clone = CloneFusionWithNewOperands(shape, new_operands, module);
+    case HloOpcode::kFusion: {
+      CHECK_NE(module, nullptr);
+      auto new_fused_computation = module->AddEmbeddedComputation(
+          fused_instructions_computation()->Clone("clone", module, clone_map));
+      clone = CreateFusion(/*shape=*/shape, /*fusion_kind=*/fusion_kind(),
+                           /*operands=*/new_operands,
+                           /*fusion_computation=*/new_fused_computation);
       break;
+    }
     case HloOpcode::kParameter:
       clone = CreateParameter(parameter_number_, shape, name_);
       break;
@@ -1481,15 +1492,19 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
   }
   SetupDerivedInstruction(clone.get());
   clone->set_parent(parent_);
+  clone->set_backend_config(backend_config());
+  if (clone_map != nullptr) {
+    InsertOrDie(clone_map, this, clone.get());
+  }
   return clone;
 }
 
 HloInstruction::~HloInstruction() {}
 
-std::unique_ptr<HloInstruction> HloInstruction::Clone(const string& suffix,
-                                                      HloModule* module) const {
+std::unique_ptr<HloInstruction> HloInstruction::Clone(
+    const string& suffix, HloModule* module, CloneMap* clone_map) const {
   std::unique_ptr<HloInstruction> clone =
-      CloneWithNewOperands(shape_, operands_, module);
+      CloneWithNewOperands(shape_, operands_, module, clone_map);
   if (suffix.empty()) {
     clone->name_ = name();
   } else {
@@ -1526,71 +1541,6 @@ std::unique_ptr<HloInstruction> HloInstruction::Clone(const string& suffix,
   return clone;
 }
 
-std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands(
-    const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
-    HloModule* module) const {
-  CHECK_EQ(opcode_, HloOpcode::kFusion);
-  CHECK(parent() != nullptr);
-
-  auto new_instruction =
-      WrapUnique(new HloInstruction(HloOpcode::kFusion, shape));
-  // Add the operands to our new fusion instruction.
-  for (HloInstruction* new_operand : operands) {
-    new_instruction->AppendOperand(new_operand);
-  }
-  // Clone all the fused instructions for the new fusion instruction.
-  HloInstructionMap<HloInstruction*> old_to_new;
-  std::list<std::unique_ptr<HloInstruction>> new_fused_instructions;
-  // Create the list of fused parameters by mapping through the cloned,
-  // fused instructions.
-  for (HloInstruction* old_fused_parameter :
-       fused_instructions_computation()->parameter_instructions()) {
-    new_fused_instructions.push_back(
-        old_fused_parameter->Clone("clone", module));
-    HloInstruction* new_fusion_parameter = new_fused_instructions.back().get();
-    InsertOrDie(&old_to_new, old_fused_parameter, new_fusion_parameter);
-  }
-  for (auto old_fused_instruction :
-       fused_instructions_computation()->MakeInstructionPostOrder()) {
-    if (old_fused_instruction->opcode() == HloOpcode::kParameter) {
-      FindOrDie(old_to_new, old_fused_instruction);
-      continue;
-    }
-    std::vector<HloInstruction*> new_operands;
-    for (int64 operand_idx = 0;
-         operand_idx < old_fused_instruction->operand_count(); ++operand_idx) {
-      HloInstruction* old_operand =
-          old_fused_instruction->mutable_operand(operand_idx);
-      new_operands.push_back(FindOrDie(old_to_new, old_operand));
-    }
-    new_fused_instructions.push_back(
-        old_fused_instruction->CloneWithNewOperands(
-            old_fused_instruction->shape(), new_operands, module));
-    HloInstruction* new_fused_instruction = new_fused_instructions.back().get();
-    new_fused_instruction->set_parent(parent_);
-    InsertOrDie(&old_to_new, old_fused_instruction, new_fused_instruction);
-  }
-  new_instruction->fusion_kind_ = fusion_kind_;
-  auto computation_builder = HloComputation::Builder(
-      fused_instructions_computation()->name() + ".clone",
-      new_instruction.get());
-  // We iterated the fusion instructions in reverse post order which means
-  // that we must reverse our new list of fusion instructions.
-  for (auto new_fused_instruction_iter = new_fused_instructions.rbegin();
-       new_fused_instruction_iter != new_fused_instructions.rend();
-       ++new_fused_instruction_iter) {
-    computation_builder.AddInstruction(std::move(*new_fused_instruction_iter));
-  }
-  if (module == nullptr) {
-    module = GetModule();
-  }
-  auto fused_root_ = fused_expression_root();
-  new_instruction->called_computations_.push_back(
-      CHECK_NOTNULL(module)->AddEmbeddedComputation(
-          computation_builder.Build(FindOrDie(old_to_new, fused_root_))));
-  return new_instruction;
-}
-
 std::pair<const HloInstruction*, ShapeIndex>
 HloInstruction::LatestNonGteAncestorAndIndex() const {
   const HloInstruction* hlo = this;
@@ -2172,6 +2122,9 @@ string HloInstruction::ToString(const HloPrintOptions& options) const {
        !metadata_.source_file().empty())) {
     StrAppend(&result, ", metadata={", xla::OpMetadataToString(metadata_), "}");
   }
+  if (options.print_backend_config() && !backend_config().empty()) {
+    StrAppend(&result, ", backend_config=\"", CEscape(backend_config()), "\"");
+  }
   return result;
 }
 
@@ -2357,6 +2310,7 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
     extra.push_back(
         StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\""));
   }
+
   return extra;
 }
 
@@ -2386,6 +2340,7 @@ HloInstructionProto HloInstruction::ToProto() const {
   }
 
   *proto.mutable_metadata() = metadata_;
+  proto.set_backend_config(backend_config());
   if (literal_ != nullptr) {
     *proto.mutable_literal() = literal_->ToProto();
   }
@@ -2971,6 +2926,7 @@ Status HloInstruction::AcceptOrdered(
       continue;
     }
 
+    // TODO(b/78350259): Eliminate const laundering.
     HloInstruction* instruction =
         const_cast<HloInstruction*>(const_instruction);
 
index a5e9aec..19c8c11 100644 (file)
@@ -66,6 +66,7 @@ class HloPrintOptions {
       : print_large_constants_(false),
         print_subcomputation_references_(true),
         print_metadata_(true),
+        print_backend_config_(true),
         compact_operands_(false),
         print_operand_shape_(true),
         print_program_shape_(true),
@@ -77,6 +78,7 @@ class HloPrintOptions {
         .set_print_large_constants(true)
         .set_print_subcomputation_references(true)
         .set_print_metadata(false)
+        .set_print_backend_config(false)
         .set_print_operand_shape(false)
         .set_print_program_shape(false)
         .set_print_percent(false);
@@ -99,12 +101,18 @@ class HloPrintOptions {
     return *this;
   }
 
-  // If true, metatdata will be printed.
+  // If true, metadata will be printed.
   HloPrintOptions& set_print_metadata(bool value) {
     print_metadata_ = value;
     return *this;
   }
 
+  // If true, backend_config will be printed.
+  HloPrintOptions& set_print_backend_config(bool value) {
+    print_backend_config_ = value;
+    return *this;
+  }
+
   // If true, operands' shapes will be printed.
   HloPrintOptions& set_print_operand_shape(bool value) {
     print_operand_shape_ = value;
@@ -141,6 +149,7 @@ class HloPrintOptions {
     return print_subcomputation_references_;
   }
   bool print_metadata() const { return print_metadata_; }
+  bool print_backend_config() const { return print_metadata_; }
   bool compact_operands() const { return compact_operands_; }
   bool print_operand_shape() const { return print_operand_shape_; }
   bool print_program_shape() const { return print_program_shape_; }
@@ -151,6 +160,7 @@ class HloPrintOptions {
   bool print_large_constants_;
   bool print_subcomputation_references_;
   bool print_metadata_;
+  bool print_backend_config_;
   bool compact_operands_;
   bool print_operand_shape_;
   bool print_program_shape_;
@@ -643,6 +653,8 @@ class HloInstruction {
   // Detaches an instruction from its operands. That is, remove the instruction
   // from each operand's user set. This should only be called prior to
   // deallocating the instruction.
+  //
+  // TODO(b/78305363): Make this automatic when deleting an instruction.
   void DetachFromOperands();
 
   // Performs a postorder DFS visit using this node as the root. If
@@ -1157,23 +1169,30 @@ class HloInstruction {
   // Precondition: opcode() == HloOpcode::kRng
   RandomDistribution random_distribution() const;
 
+  // See documentation for Clone().
+  using CloneMap = std::unordered_map<const HloInstruction*, HloInstruction*>;
+
   // Clones the HLO instruction. The clone will have the same opcode, shape, and
   // operands. After creation the clone has no uses. "this" (the instruction
   // cloned from) is not changed. Suffix is the string to append to the name of
-  // the instruction to form the name of the cloned instruction.  If the module
-  // pointer is not nullptr, it will be the module where the cloned computations
-  // will be added to (in order to support deep cloning).  Ignores the control
-  // predecessors and successors of this HLO instruction.
+  // the instruction to form the name of the cloned instruction. Ignores the
+  // control predecessors and successors of this HLO instruction.
+  //
+  // If the module pointer is not nullptr, then any cloned computations will be
+  // added to this module in order to support deep cloning. Otherwise the module
+  // of the instruction is used.
+  //
+  // If clone_map is not nullptr, then each original instruction that is cloned
+  // will be inserted and map to its clone. clone_map should not already contain
+  // any of the instructions to clone.
   std::unique_ptr<HloInstruction> Clone(const string& suffix = "clone",
-                                        HloModule* module = nullptr) const;
+                                        HloModule* module = nullptr,
+                                        CloneMap* clone_map = nullptr) const;
 
-  // Clones the HLO instruction as above but with new shape and operands.  If
-  // the module pointer is not nullptr, it will be the module where the cloned
-  // computations will be added to (in order to support deep cloning).  Ignores
-  // the control predecessors and successors of this HLO instruction.
+  // Clones the HLO instruction as above but with new shape and operands.
   std::unique_ptr<HloInstruction> CloneWithNewOperands(
       const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
-      HloModule* module = nullptr) const;
+      HloModule* module = nullptr, CloneMap* clone_map = nullptr) const;
 
   // Returns the computations this instruction directly calls (if any).
   const std::vector<HloComputation*>& called_computations() const {
@@ -1262,6 +1281,19 @@ class HloInstruction {
   // if no id has been assigned yet).
   int unique_id() const { return unique_id_; }
 
+  // Returns the backend-specific configuration for how a backend should compile
+  // this HLO. The meaning of the field is backend specific. Not for use before
+  // or during general HLO optimization, since HLO optimizations do not preserve
+  // this field and they cannot interpret it due to its meaning being backend
+  // specific.
+  //
+  // TODO(b/78194644): Introduce structured configuration format as per
+  // go/xla-heuristics.
+  const string& backend_config() const { return backend_config_; }
+  void set_backend_config(string backend_config) {
+    backend_config_ = std::move(backend_config);
+  }
+
   // Sets the debug metadata for this instruction.
   void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; }
   const OpMetadata& metadata() const { return metadata_; }
@@ -1283,6 +1315,7 @@ class HloInstruction {
   // Get/Set the number of partitions per outer dimension (in order, starting
   // with outer-most dimension first). Currently used by the parallel cpu
   // backend to partition HLOs into parallel tasks.
+  //
   // TODO(b/62783254) Replace these methods with a more general way to
   // annotate HLOs with backend-specific information.
   const std::vector<int64>& outer_dimension_partitions() const {
@@ -1510,6 +1543,10 @@ class HloInstruction {
   // The string representation of the infeed configuration.
   string infeed_config_;
 
+  // The backend-specific configuration for how a backend should compile this
+  // HLO. See the documentation on backend_config().
+  string backend_config_;
+
   // String identifier for instruction.
   string name_;
 
index c7a7192..5308fb5 100644 (file)
@@ -46,6 +46,18 @@ HloModule::HloModule(const string& name, const HloModuleConfig& config)
       config_(config),
       unique_id_(next_unique_module_id_++) {}
 
+StatusOr<HloInstruction*> HloModule::LaunderConstInstructionFromModule(
+    const HloInstruction* hlo) {
+  if (hlo == nullptr) {
+    return nullptr;
+  }
+
+  TF_RET_CHECK(hlo->GetModule() == this);
+
+  // TODO(b/78350259): Eliminate const laundering.
+  return const_cast<HloInstruction*>(hlo);
+}
+
 HloComputation* HloModule::AddComputationInternal(
     std::unique_ptr<HloComputation> computation, bool is_entry,
     bool uniquify_names) {
index f9674df..1604a72 100644 (file)
@@ -217,6 +217,25 @@ class HloModule {
   // the lifetime of this process.
   int unique_id() const { return unique_id_; }
 
+  // Returns a non-const version of the passed-in const HloInstruction*. This is
+  // safe on the argument that if you have a non-const module, then you can
+  // access all instructions in the module as non-const.
+  //
+  // Returns an error if the passed-in instruction is not from this module,
+  // except that it is allowed to pass in a null pointer.
+  //
+  // TODO(b/78350259): Eliminate const laundering. The argument above is not
+  // reliable since at any time someone could add or discover a way for a
+  // non-const module to transitively contain a const HloInstruction. The
+  // reliable way to do this would be to create a const laundering map from a
+  // module, mapping each encountered HloInstruction to its non-const version
+  // and then look up each instruction in need of laundering in that map, but
+  // this is much more expensive and complicated. This returns a Status instead
+  // of doing a CHECK-failure in part to make it strongly apparent that this is
+  // something that can fail.
+  StatusOr<HloInstruction*> LaunderConstInstructionFromModule(
+      const HloInstruction* hlo);
+
  private:
   HloComputation* AddComputationInternal(
       std::unique_ptr<HloComputation> computation, bool is_entry,
index 8a30cbf..096ebb7 100644 (file)
@@ -116,7 +116,7 @@ Status ShapeVerifier::HandleOutfeed(HloInstruction* outfeed) {
   // produces no HLO value in the graph.
   if (!ShapeUtil::Compatible(outfeed->outfeed_shape(),
                              outfeed->operand(0)->shape())) {
-    return InvalidArgument(
+    return InternalError(
         "Expected outfeed to have shape compatible with operand's shape %s, "
         "actual shape is %s:\n%s",
         ShapeUtil::HumanString(outfeed->operand(0)->shape()).c_str(),
@@ -200,7 +200,7 @@ Status ShapeVerifier::HandleTranspose(HloInstruction* transpose) {
                      transpose->operand(0)->shape(), transpose->dimensions()));
 }
 
-Status ShapeVerifier::HandleParameter(HloInstruction*) {
+Status ShapeVerifier::HandleParameter(HloInstruction* hlo) {
   return tensorflow::Status::OK();
 }
 
@@ -410,7 +410,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) {
               if (fp_type == PRIMITIVE_TYPE_INVALID) {
                 fp_type = subshape.element_type();
               } else if (fp_type != subshape.element_type()) {
-                return FailedPrecondition(
+                return InternalError(
                     "Seen floating point types of different precisions in "
                     "%s, but mixed precision is disallowed.",
                     instruction->ToString().c_str());
@@ -490,7 +490,7 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
       }
   }
   if (!compatible) {
-    return InvalidArgument(
+    return InternalError(
         "Expected instruction to have shape compatible with %s, actual "
         "shape is %s:\n%s",
         ShapeUtil::HumanString(inferred_shape).c_str(),
@@ -541,7 +541,7 @@ Status ShapeVerifier::CheckVariadicShape(const HloInstruction* instruction) {
 Status ShapeVerifier::CheckSameChannel(const HloInstruction* instr1,
                                        const HloInstruction* instr2) {
   if (instr1->channel_id() != instr2->channel_id()) {
-    return FailedPrecondition(
+    return InternalError(
         "Expected to have the same channel id, actual channel ids are: %s "
         "(%lld), %s (%lld)",
         instr1->ToString().c_str(), instr1->channel_id(),
@@ -571,22 +571,22 @@ string ComputationsToString(
 Status VerifyHloStructure(HloModule* module) {
   for (const HloComputation* computation : module->computations()) {
     if (computation->parent() == nullptr) {
-      return FailedPrecondition("Computation %s has a null parent pointer",
-                                computation->name().c_str());
+      return InternalError("Computation %s has a null parent pointer",
+                           computation->name().c_str());
     }
     if (computation->parent() != module) {
-      return FailedPrecondition(
+      return InternalError(
           "Computation %s parent() does not point to parent module",
           computation->name().c_str());
     }
 
     for (const HloInstruction* instruction : computation->instructions()) {
       if (instruction->parent() == nullptr) {
-        return FailedPrecondition("Instruction %s has a null parent pointer",
-                                  instruction->name().c_str());
+        return InternalError("Instruction %s has a null parent pointer",
+                             instruction->name().c_str());
       }
       if (instruction->parent() != computation) {
-        return FailedPrecondition(
+        return InternalError(
             "Instruction %s parent() does not point to parent computation",
             instruction->name().c_str());
       }
@@ -602,7 +602,7 @@ Status VerifyHloStructure(HloModule* module) {
       for (int i = 0; i < instruction->operand_count(); ++i) {
         const HloInstruction* operand = instruction->operand(i);
         if (operand->parent() != instruction->parent()) {
-          return FailedPrecondition(
+          return InternalError(
               "Operand %d (%s) of instruction %s is in a different "
               "computation: %s vs %s",
               i, operand->name().c_str(), instruction->name().c_str(),
@@ -619,7 +619,7 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
   // The parent fusion instruction of the fusion computation must be 'fusion'.
   HloComputation* fused_computation = fusion->fused_instructions_computation();
   if (fusion != fused_computation->FusionInstruction()) {
-    return FailedPrecondition(
+    return InternalError(
         "Instruction of fused computation does not match expected instruction "
         "%s.",
         fusion->ToString().c_str());
@@ -635,37 +635,37 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
   for (auto* instruction : fused_computation->instructions()) {
     if (fused_root == instruction) {
       if (root_owned) {
-        return FailedPrecondition("Root appears more than once in %s.",
-                                  fusion->ToString().c_str());
+        return InternalError("Root appears more than once in %s.",
+                             fusion->ToString().c_str());
       }
       root_owned = true;
     }
     for (int i = 0; i < fused_parameters.size(); ++i) {
       if (fused_parameters[i] == instruction) {
         if (parameter_owned[i]) {
-          return FailedPrecondition("Parameter appears more than once in %s.",
-                                    fusion->ToString().c_str());
+          return InternalError("Parameter appears more than once in %s.",
+                               fusion->ToString().c_str());
         }
         parameter_owned[i] = true;
       }
     }
   }
   if (!root_owned) {
-    return FailedPrecondition("Root not found in computation of %s.",
-                              fusion->ToString().c_str());
+    return InternalError("Root not found in computation of %s.",
+                         fusion->ToString().c_str());
   }
   // Make sure all the parameter_owned entries are set
   for (int i = 0; i < parameter_owned.size(); i++) {
     if (!parameter_owned[i]) {
-      return FailedPrecondition("Parameter %d not found in computation of %s.",
-                                i, fusion->ToString().c_str());
+      return InternalError("Parameter %d not found in computation of %s.", i,
+                           fusion->ToString().c_str());
     }
   }
 
   // Fused root must have no users.
   if (fused_root->user_count() != 0) {
-    return FailedPrecondition("Root of %s may not have users.",
-                              fusion->ToString().c_str());
+    return InternalError("Root of %s may not have users.",
+                         fusion->ToString().c_str());
   }
 
   // All uses of fused instructions must be in the fusion computation, and every
@@ -674,13 +674,13 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
        fusion->fused_instructions_computation()->instructions()) {
     if (instruction != fused_root) {
       if (instruction->user_count() == 0) {
-        return FailedPrecondition(
-            "Non-root instruction %s in %s must have users.",
-            instruction->ToString().c_str(), fusion->ToString().c_str());
+        return InternalError("Non-root instruction %s in %s must have users.",
+                             instruction->ToString().c_str(),
+                             fusion->ToString().c_str());
       }
       for (auto& user : instruction->users()) {
         if (fused_computation != user->parent()) {
-          return FailedPrecondition(
+          return InternalError(
               "Non-root instruction %s in %s may not have external users.",
               instruction->ToString().c_str(), fusion->ToString().c_str());
         }
@@ -695,34 +695,33 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
   for (auto fused_param : fused_parameters) {
     int64 param_no = fused_param->parameter_number();
     if (param_no < 0) {
-      return FailedPrecondition(
-          "Unexpected negative parameter number %lld in %s.", param_no,
-          fusion->ToString().c_str());
+      return InternalError("Unexpected negative parameter number %lld in %s.",
+                           param_no, fusion->ToString().c_str());
     }
     if (param_no >= fused_parameters.size()) {
-      return FailedPrecondition(
+      return InternalError(
           "Unexpected parameter number %lld in %s: higher then number of "
           "parameters %lu.",
           param_no, fusion->ToString().c_str(), fused_parameters.size());
     }
     if (parameter_numbers[param_no]) {
-      return FailedPrecondition(
+      return InternalError(
           "Did not expect parameter number %lld more than once in %s.",
           param_no, fusion->ToString().c_str());
     }
     parameter_numbers[param_no] = true;
     if (!ShapeUtil::Compatible(fused_param->shape(),
                                fusion->operand(param_no)->shape())) {
-      return FailedPrecondition(
+      return InternalError(
           "Shape mismatch between parameter number %lld and its operand in %s.",
           param_no, fusion->ToString().c_str());
     }
   }
-  // Make sure all the parameter_numbers entries were seen
+  // Make sure all the parameter_numbers entries were seen.
   for (int i = 0; i < parameter_numbers.size(); i++) {
     if (!parameter_numbers[i]) {
-      return FailedPrecondition("Did not see parameter number %d in %s.", i,
-                                fusion->ToString().c_str());
+      return InternalError("Did not see parameter number %d in %s.", i,
+                           fusion->ToString().c_str());
     }
   }
 
index cccbce5..0e1387c 100644 (file)
@@ -13,13 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-// StatusOr<T> is the union of a Status object and a T
-// object. StatusOr models the concept of an object that is either a
-// usable value, or an error Status explaining why such a value is
-// not present. To this end, StatusOr<T> does not allow its Status
-// value to be Status::OK. Furthermore, the value of a StatusOr<T*>
-// must not be null. This is enforced by a debug check in most cases,
-// but even when it is not, clients must not set the value to null.
+// StatusOr<T> is the union of a Status object and a T object. StatusOr models
+// the concept of an object that is either a value, or an error Status
+// explaining why such a value is not present. To this end, StatusOr<T> does not
+// allow its Status value to be Status::OK.
 //
 // The primary use-case for StatusOr<T> is as the return value of a
 // function which may fail.
index f9d2594..7d76370 100644 (file)
@@ -75,6 +75,14 @@ TEST(StatusOr, ElementType) {
   static_assert(std::is_same<StatusOr<char>::element_type, char>(), "");
 }
 
+TEST(StatusOr, NullPointerStatusOr) {
+  // As a very special case, null-plain-pointer StatusOr used to be an
+  // error. Test that it no longer is.
+  StatusOr<int*> null_status(nullptr);
+  EXPECT_TRUE(null_status.ok());
+  EXPECT_EQ(null_status.ValueOrDie(), nullptr);
+}
+
 TEST(StatusOr, TestNoDefaultConstructorInitialization) {
   // Explicitly initialize it with an error code.
   StatusOr<NoDefaultConstructor> statusor(tensorflow::errors::Cancelled(""));
index 40dc073..156a06c 100644 (file)
@@ -440,6 +440,10 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
   optional<OpMetadata> metadata;
   attrs["metadata"] = {/*required=*/false, AttrTy::kMetadata, &metadata};
 
+  optional<string> backend_config;
+  attrs["backend_config"] = {/*required=*/false, AttrTy::kString,
+                             &backend_config};
+
   HloInstruction* instruction;
   switch (opcode) {
     case HloOpcode::kParameter: {
@@ -1094,8 +1098,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
 
   instruction->set_name(name);
 
-  // Add common attrs (sharding, control predecessors) to the instruction, if
-  // they were seen.
+  // Add shared attributes like metadata to the instruction, if they were seen.
   if (sharding) {
     instruction->set_sharding(
         HloSharding::FromProto(sharding.value()).ValueOrDie());
@@ -1112,6 +1115,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
   if (metadata) {
     instruction->set_metadata(*metadata);
   }
+  if (backend_config) {
+    instruction->set_backend_config(std::move(*backend_config));
+  }
   return AddInstruction(name, instruction, name_loc);
 }  // NOLINT(readability/fn_size)
 
index d38d890..e100d8c 100644 (file)
@@ -65,7 +65,7 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
 R"(HloModule constant_pred_module
 
 ENTRY %constant_pred () -> pred[] {
-  ROOT %constant = pred[] constant(true), metadata={op_type="const" op_name="\"it\'s not a problem\n" source_file="path/to/test.cc" source_line=68}
+  ROOT %constant = pred[] constant(true), metadata={op_type="const" op_name="\"it\'s not a problem\n" source_file="path/to/test.cc" source_line=68}, backend_config="foo\" bar"
 }
 
 )"
@@ -81,13 +81,14 @@ ENTRY %constant_s32 () -> s32[] {
 
 )"
 },
-// f32 constant, but the value is not a decimal
+// f32 constant, but the value is not a decimal and there is a backend
+// configuration
 {
 "ConstantF32",
 R"(HloModule ConstantF32_module
 
 ENTRY %ConstantF32.v4 () -> f32[] {
-  ROOT %constant = f32[] constant(42)
+  ROOT %constant = f32[] constant(42), backend_config="this is a configuration"
 }
 
 )"
@@ -1013,6 +1014,19 @@ ENTRY %SelectScalarS32True.v4 () -> s32[] {
   // but the constant names will not be exactly the same.
 }
 
+TEST_F(HloParserTest, ConfigurationField) {
+  const string original = R"(HloModule AModule
+ENTRY %configuration_test() -> s32[] {
+  %constant = s32[] constant(42), backend_config="foo bar"
+})";
+  auto result = Parse(original);
+  TF_ASSERT_OK(result.status());
+  EXPECT_EQ("foo bar", result.ValueOrDie()
+                           ->entry_computation()
+                           ->root_instruction()
+                           ->backend_config());
+}
+
 TEST_F(HloParserTest, LiteralDimensionsMismatch_1) {
   const string original = R"(HloModule some_2_module
 
@@ -1092,7 +1106,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2
   %input = f32[1,2,1]{2,1,0} parameter(0)
   %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
   %filter = f32[1,1,1]{2,1,0} parameter(1)
-  ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), sharding={maximal device=1}, dim_labels=b0f_0io->b0f, window={pad=1_1 size=2}
+  ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), sharding={maximal device=1}, backend_config="foo", dim_labels=b0f_0io->b0f, window={pad=1_1 size=2}
 }
 
 )";