Make IndexedArrayAnalysis behave well around StatusOr
author    Sanjoy Das <sanjoy@google.com>
          Tue, 29 May 2018 05:16:46 +0000 (22:16 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
          Tue, 29 May 2018 05:19:19 +0000 (22:19 -0700)
PiperOrigin-RevId: 198348355

tensorflow/compiler/xla/service/indexed_array_analysis.cc
tensorflow/compiler/xla/service/indexed_array_analysis.h
tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
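
This change replaces crash-on-error unwrapping (ValueOrDie / ConsumeValueOrDie) in IndexedArrayAnalysis with propagated errors: fallible helpers now return StatusOr<T> or Status, and call sites unwrap them with TF_ASSIGN_OR_RETURN / TF_RETURN_IF_ERROR. A minimal, self-contained sketch of that pattern, separate from the diff below (names are hypothetical and include paths are approximate):

#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/errors.h"

namespace sketch {

// A fallible computation returns StatusOr<T> instead of crashing on failure.
xla::StatusOr<int> HalfOf(int n) {
  if (n % 2 != 0) {
    return tensorflow::errors::InvalidArgument("expected an even value");
  }
  return n / 2;
}

// TF_ASSIGN_OR_RETURN unwraps the value on success and returns the error
// Status to *this* function's caller on failure.
xla::StatusOr<int> QuarterOf(int n) {
  TF_ASSIGN_OR_RETURN(int half, HalfOf(n));
  TF_ASSIGN_OR_RETURN(int quarter, HalfOf(half));
  return quarter;
}

// Callers that only care about success/failure propagate a plain Status
// with TF_RETURN_IF_ERROR.
tensorflow::Status CheckDivisibleByFour(int n) {
  TF_RETURN_IF_ERROR(QuarterOf(n).status());
  return tensorflow::Status::OK();
}

}  // namespace sketch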

diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
index 5d870f9..21af9a6 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
@@ -33,8 +33,6 @@ using tensorflow::gtl::ArraySlice;
 using tensorflow::str_util::Join;
 }  // namespace
 
-// TODO(sanjoy): Make this pass StatusOr safe.
-
 string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) {
   switch (root->kind()) {
     case Array::kUnknown: {
@@ -69,18 +67,18 @@ string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) {
   }
 }
 
-Analysis::Array* IndexedArrayAnalysis::GetArrayFor(
+StatusOr<Analysis::Array*> IndexedArrayAnalysis::GetArrayFor(
     const HloInstruction* instr) {
   auto it = cache_.find(instr);
   if (it != cache_.end()) {
     return it->second;
   }
 
-  TraverseAndPopulateCache(instr);
+  TF_RETURN_IF_ERROR(TraverseAndPopulateCache(instr));
   return FindOrDie(cache_, instr);
 }
 
-void IndexedArrayAnalysis::TraverseAndPopulateCache(
+Status IndexedArrayAnalysis::TraverseAndPopulateCache(
     const HloInstruction* root) {
   // Depth first search over the DAG, invoking ComputeArrayFor in post order.
   // The HLO instructions already in the cache are considered leaves.
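
The comment above describes the traversal: an iterative post-order DFS over the instruction DAG using an explicit stack, with per-node failures now propagated out of the loop instead of crashing. A generic, self-contained sketch of that shape, separate from the commit (the node type, state names, and include paths are stand-ins; the real code walks HloInstruction* and fills cache_, as the hunk below shows):

#include <unordered_map>
#include <vector>

#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace sketch {

struct Node {
  std::vector<const Node*> operands;
};

// Stand-in for the fallible per-node work (ComputeArrayFor in the real code).
xla::StatusOr<int> ComputeValueFor(const Node* node) {
  return static_cast<int>(node->operands.size());
}

tensorflow::Status PopulateCache(
    const Node* root, std::unordered_map<const Node*, int>* cache) {
  enum State { kNotVisited, kOperandsPushed, kDone };
  std::unordered_map<const Node*, State> state;  // absent keys read as kNotVisited
  std::vector<const Node*> stack = {root};
  do {
    const Node* node = stack.back();
    switch (state[node]) {
      case kNotVisited:
        // First visit: push un-computed operands so they finish first.
        state[node] = kOperandsPushed;
        for (const Node* operand : node->operands) {
          if (cache->count(operand) == 0) {
            stack.push_back(operand);
          }
        }
        break;
      case kOperandsPushed: {
        // All operands are cached now; compute this node (post order).
        // An error here aborts the whole traversal.
        state[node] = kDone;
        stack.pop_back();
        TF_ASSIGN_OR_RETURN(int value, ComputeValueFor(node));
        (*cache)[node] = value;
        break;
      }
      case kDone:
        // Reached via a duplicate stack entry (shared operand); nothing to do.
        stack.pop_back();
        break;
    }
  } while (!stack.empty());
  return tensorflow::Status::OK();
}

}  // namespace sketch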
@@ -116,32 +114,42 @@ void IndexedArrayAnalysis::TraverseAndPopulateCache(
 
       case kVisited:
         stack.pop_back();
-        InsertOrDie(&cache_, instr, ComputeArrayFor(instr));
+        TF_ASSIGN_OR_RETURN(Array * array, ComputeArrayFor(instr));
+        InsertOrDie(&cache_, instr, array);
         break;
     }
   } while (!stack.empty());
+
+  return Status::OK();
 }
 
-Analysis::Array* IndexedArrayAnalysis::ComputeArrayFor(
+StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayFor(
     const HloInstruction* instr) {
   Array* computed_array;
   if (instr->IsElementwise() && instr->operand_count() == 1) {
-    computed_array = ComputeArrayForElementwiseUnaryOp(
-        instr, FindOrDie(cache_, instr->operand(0)));
+    TF_ASSIGN_OR_RETURN(computed_array,
+                        ComputeArrayForElementwiseUnaryOp(
+                            instr, FindOrDie(cache_, instr->operand(0))));
   } else if (instr->IsElementwise() && instr->operand_count() == 2) {
-    computed_array = ComputeArrayForElementwiseBinaryOp(
-        instr, FindOrDie(cache_, instr->operand(0)),
-        FindOrDie(cache_, instr->operand(1)));
+    TF_ASSIGN_OR_RETURN(computed_array,
+                        ComputeArrayForElementwiseBinaryOp(
+                            instr, FindOrDie(cache_, instr->operand(0)),
+                            FindOrDie(cache_, instr->operand(1))));
   } else if (instr->opcode() == HloOpcode::kConstant) {
-    computed_array = ComputeArrayForConstant(instr->literal());
+    TF_ASSIGN_OR_RETURN(computed_array,
+                        ComputeArrayForConstant(instr->literal()));
   } else if (instr->opcode() == HloOpcode::kGather) {
-    computed_array = ComputeArrayForGather(
-        instr->shape(), instr->gather_dimension_numbers(),
-        instr->gather_window_bounds(), FindOrDie(cache_, instr->operand(0)),
-        FindOrDie(cache_, instr->operand(1)));
+    TF_ASSIGN_OR_RETURN(
+        computed_array,
+        ComputeArrayForGather(instr->shape(), instr->gather_dimension_numbers(),
+                              instr->gather_window_bounds(),
+                              FindOrDie(cache_, instr->operand(0)),
+                              FindOrDie(cache_, instr->operand(1))));
   } else if (instr->opcode() == HloOpcode::kReshape) {
-    computed_array = ComputeArrayForReshape(
-        instr->shape(), FindOrDie(cache_, instr->operand(0)));
+    TF_ASSIGN_OR_RETURN(
+        computed_array,
+        ComputeArrayForReshape(instr->shape(),
+                               FindOrDie(cache_, instr->operand(0))));
   } else {
     computed_array = nullptr;
   }
@@ -153,12 +161,12 @@ Analysis::Array* IndexedArrayAnalysis::ComputeArrayFor(
   return computed_array;
 }
 
-Analysis::Array* IndexedArrayAnalysis::ComputeArrayForConstant(
+StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForConstant(
     const Literal& literal) {
   return Construct<ConstantArray>(&literal);
 }
 
-ScalarIndexedArray* IndexedArrayAnalysis::FoldGatherOfGather(
+StatusOr<ScalarIndexedArray*> IndexedArrayAnalysis::FoldGatherOfGather(
     ScalarIndexedArray* source, Array* indices, int64 source_dim,
     tensorflow::gtl::ArraySlice<int64> output_dims, Shape shape) {
   // We want to transform Gather(Gather(A, X), Y) => Gather(A, Gather(X, Y)).
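
The comment above states the rewrite FoldGatherOfGather implements. A tiny 1-D illustration of the underlying indexing identity, using plain std::vector rather than XLA's Gather (illustrative only, not part of the commit):

#include <cassert>
#include <vector>

// result[i] = source[indices[i]]
std::vector<int> Gather1D(const std::vector<int>& source,
                          const std::vector<int>& indices) {
  std::vector<int> result;
  result.reserve(indices.size());
  for (int index : indices) result.push_back(source[index]);
  return result;
}

void GatherOfGatherIdentity() {
  const std::vector<int> a = {10, 20, 30, 40};  // A
  const std::vector<int> x = {3, 1, 2};         // X
  const std::vector<int> y = {0, 2};            // Y
  // Gather(Gather(A, X), Y)[i] == A[X[Y[i]]] == Gather(A, Gather(X, Y))[i],
  // so the two gathers over A collapse into one whose index array is X
  // gathered by Y.
  assert(Gather1D(Gather1D(a, x), y) == Gather1D(a, Gather1D(x, y)));
}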
@@ -224,7 +232,7 @@ ScalarIndexedArray* IndexedArrayAnalysis::FoldGatherOfGather(
                                      std::move(shape));
 }
 
-Analysis::Array* IndexedArrayAnalysis::ComputeArrayForGather(
+StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForGather(
     const Shape& shape, const GatherDimensionNumbers& dim_numbers,
     tensorflow::gtl::ArraySlice<int64> window_bounds, Array* source,
     Array* indices) {
@@ -397,7 +405,7 @@ int64 FindSourcePositionForPassthroughResultDim(ArraySlice<int64> operand_shape,
 
 };  // namespace
 
-Analysis::Array* IndexedArrayAnalysis::ComputeArrayForReshape(
+StatusOr<Analysis::Array*> IndexedArrayAnalysis::ComputeArrayForReshape(
     const Shape& shape, Array* operand) {
   auto* scalar_indexed = dynamic_cast<ScalarIndexedConstantArray*>(operand);
   if (!scalar_indexed) {
@@ -541,10 +549,12 @@ Analysis::Array* IndexedArrayAnalysis::ComputeArrayForReshape(
               std::back_inserter(output_dims_for_new_scalar_indexed_node),
               map_passthrough_operand_dim_to_result_dim);
 
-  Array* new_scalar_indexed_source = ComputeArrayForConstant(
-      *TakeOwnership(scalar_indexed->literal()
-                         .Reshape(new_scalar_indexed_source_shape)
-                         .ValueOrDie()));
+  TF_ASSIGN_OR_RETURN(const Literal* new_scalar_indexed_source_literal,
+                      TakeOwnership(scalar_indexed->literal().Reshape(
+                          new_scalar_indexed_source_shape)));
+  TF_ASSIGN_OR_RETURN(
+      Array * new_scalar_indexed_source,
+      ComputeArrayForConstant(*new_scalar_indexed_source_literal));
 
   return ConstructScalarIndexedArray(
       new_scalar_indexed_source, scalar_indexed->indices(),
@@ -552,7 +562,8 @@ Analysis::Array* IndexedArrayAnalysis::ComputeArrayForReshape(
       output_dims_for_new_scalar_indexed_node, shape);
 }
 
-Analysis::Array* IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(
+StatusOr<Analysis::Array*>
+IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(
     const HloInstruction* instr, Array* lhs, Array* rhs) {
   // Try to fold BinaryOp(Broadcast(Const0), ScalarIndexed(Const1, Indices))
   //          => ScalarIndexed(BinaryOp(Broadcast'(Const0), Const1), Indices)
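
The comment above names the fold attempted by ComputeArrayForElementwiseBinaryOp. A 1-D illustration of why it is sound when one operand is a broadcast of a constant, again with plain std::vector (illustrative only, not part of the commit): the elementwise op commutes with the gather, so it can be applied once to the gather's source.

#include <cassert>
#include <vector>

void BinaryOpFoldIdentity() {
  const int c0 = 5;                          // Broadcast(Const0), a scalar here
  const std::vector<int> c1 = {1, 2, 3, 4};  // Const1
  const std::vector<int> indices = {2, 0, 3};

  // Unfolded: gather from Const1 first, then apply the binary op.
  std::vector<int> unfolded;
  for (int index : indices) unfolded.push_back(c0 + c1[index]);

  // Folded: apply the op to the source once -- BinaryOp(Broadcast'(Const0),
  // Const1) -- and gather from that result with the same indices.
  std::vector<int> new_source;
  for (int value : c1) new_source.push_back(c0 + value);
  std::vector<int> folded;
  for (int index : indices) folded.push_back(new_source[index]);

  assert(unfolded == folded);
}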
@@ -642,28 +653,25 @@ Analysis::Array* IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(
 
   // inner_broadcast_result is the Broadcast'(Const0) bit in
   // BinaryOp(Broadcast'(Const0), Const1)
-  std::unique_ptr<Literal> inner_broadcast_result =
-      broadcast_const_operand->literal()
-          .Broadcast(scalar_indexed_const->source()->shape(),
-                     new_inner_broadcast_dims)
-          .ConsumeValueOrDie();
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<Literal> inner_broadcast_result,
+      broadcast_const_operand->literal().Broadcast(
+          scalar_indexed_const->source()->shape(), new_inner_broadcast_dims));
 
   // literal_for_new_source is BinaryOp(Broadcast'(Const0), Const1)
   const Literal* literal_for_new_source;
   if (lhs_is_indexed) {
-    literal_for_new_source =
-        TakeOwnership(HloEvaluator{}
-                          .EvaluateElementwiseBinaryOp(
-                              instr->opcode(), scalar_indexed_const->literal(),
-                              *inner_broadcast_result)
-                          .ConsumeValueOrDie());
+    TF_ASSIGN_OR_RETURN(
+        literal_for_new_source,
+        TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp(
+            instr->opcode(), scalar_indexed_const->literal(),
+            *inner_broadcast_result)));
   } else {
-    literal_for_new_source =
-        TakeOwnership(HloEvaluator{}
-                          .EvaluateElementwiseBinaryOp(
-                              instr->opcode(), *inner_broadcast_result,
-                              scalar_indexed_const->literal())
-                          .ConsumeValueOrDie());
+    TF_ASSIGN_OR_RETURN(
+        literal_for_new_source,
+        TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp(
+            instr->opcode(), *inner_broadcast_result,
+            scalar_indexed_const->literal())));
   }
 
   ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
@@ -675,7 +683,8 @@ Analysis::Array* IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp(
       scalar_indexed_const->shape());
 }
 
-Analysis::Array* IndexedArrayAnalysis::ComputeArrayForElementwiseUnaryOp(
+StatusOr<Analysis::Array*>
+IndexedArrayAnalysis::ComputeArrayForElementwiseUnaryOp(
     const HloInstruction* instr, Array* operand) {
   auto* scalar_indexed_const =
       dynamic_cast<ScalarIndexedConstantArray*>(operand);
@@ -686,11 +695,9 @@ Analysis::Array* IndexedArrayAnalysis::ComputeArrayForElementwiseUnaryOp(
   // Fold UnaryOp(ScalarIndexed(Const, Indices))
   //   => ScalarIndexed(UnaryOp(Const), Indices)
 
-  Literal* literal_for_new_source =
-      TakeOwnership(HloEvaluator{}
-                        .EvaluateElementwiseUnaryOp(
-                            instr->opcode(), scalar_indexed_const->literal())
-                        .ConsumeValueOrDie());
+  TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source,
+                      TakeOwnership(HloEvaluator{}.EvaluateElementwiseUnaryOp(
+                          instr->opcode(), scalar_indexed_const->literal())));
   ConstantArray* new_source = Construct<ConstantArray>(literal_for_new_source);
   return Construct<ScalarIndexedConstantArray>(
       new_source, scalar_indexed_const->indices(),
@@ -712,7 +719,7 @@ StatusOr<bool> IndexedArrayAnalysisPrinterPass::Run(HloModule* module) {
   IndexedArrayAnalysis analysis;
   for (auto* computation : module->MakeNonfusionComputations()) {
     for (auto* instr : computation->instructions()) {
-      auto* t = analysis.GetArrayFor(instr);
+      TF_ASSIGN_OR_RETURN(Analysis::Array * t, analysis.GetArrayFor(instr));
       if (!dynamic_cast<UnknownArray*>(t) && !dynamic_cast<ConstantArray*>(t)) {
         VLOG(2) << instr->ToString() << "   ->   " << analysis.ToString(t);
       }
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h
index 8c1f616..561832a 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.h
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h
@@ -220,7 +220,7 @@ class IndexedArrayAnalysis {
   // NB!  By inspecting the implementation, you may be able to infer a stronger
   // caching guarantee than what is mentioned above.  Nevertheless, what is
   // stated above is the contract.
-  Array* GetArrayFor(const HloInstruction* instr);
+  StatusOr<Array*> GetArrayFor(const HloInstruction* instr);
 
   // Pretty-prints the expression rooted at `root`.
   string ToString(Array* root, bool print_constants = false);
@@ -228,18 +228,18 @@ class IndexedArrayAnalysis {
  private:
   // Helper function that ensures that every HLO instruction that is
   // transitively used by `root` has an entry in `cache_`.
-  void TraverseAndPopulateCache(const HloInstruction* root);
+  Status TraverseAndPopulateCache(const HloInstruction* root);
 
   // Creates an Array instance for `instr` under the assumption that all
   // operations of `instr` are present in `cache_`.
-  Array* ComputeArrayFor(const HloInstruction* instr);
+  StatusOr<Array*> ComputeArrayFor(const HloInstruction* instr);
 
-  Array* ComputeArrayForConstant(const Literal& literal);
+  StatusOr<Array*> ComputeArrayForConstant(const Literal& literal);
 
-  Array* ComputeArrayForGather(const Shape& shape,
-                               const GatherDimensionNumbers& dim_numbers,
-                               tensorflow::gtl::ArraySlice<int64> window_bounds,
-                               Array* source, Array* indices);
+  StatusOr<Array*> ComputeArrayForGather(
+      const Shape& shape, const GatherDimensionNumbers& dim_numbers,
+      tensorflow::gtl::ArraySlice<int64> window_bounds, Array* source,
+      Array* indices);
 
   // This tries to fold a ScalarIndexedArray which has another
   // ScalarIndexedArray as a source into a ScalarIndexedArray that instead has a
@@ -262,16 +262,16 @@ class IndexedArrayAnalysis {
   //
   //    I2 = [I0[i]  for i in I1]
   //    G1 = [Arr[i] for i in I2]
-  ScalarIndexedArray* FoldGatherOfGather(
+  StatusOr<ScalarIndexedArray*> FoldGatherOfGather(
       ScalarIndexedArray* source, Array* indices, int64 source_dim,
       tensorflow::gtl::ArraySlice<int64> output_dims, Shape shape);
 
-  Array* ComputeArrayForReshape(const Shape& shape, Array* operand);
+  StatusOr<Array*> ComputeArrayForReshape(const Shape& shape, Array* operand);
 
-  Array* ComputeArrayForElementwiseBinaryOp(const HloInstruction* instr,
-                                            Array* lhs, Array* rhs);
-  Array* ComputeArrayForElementwiseUnaryOp(const HloInstruction* instr,
-                                           Array* operand);
+  StatusOr<Array*> ComputeArrayForElementwiseBinaryOp(
+      const HloInstruction* instr, Array* lhs, Array* rhs);
+  StatusOr<Array*> ComputeArrayForElementwiseUnaryOp(
+      const HloInstruction* instr, Array* operand);
 
   template <typename T, typename... Args>
   T* Construct(Args&&... args) {
@@ -299,6 +299,14 @@ class IndexedArrayAnalysis {
     return owned_literals_.back().get();
   }
 
+  StatusOr<Literal*> TakeOwnership(
+      StatusOr<std::unique_ptr<Literal>> literal_or_error) {
+    TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
+                        std::move(literal_or_error));
+    owned_literals_.push_back(std::move(literal));
+    return owned_literals_.back().get();
+  }
+
   std::vector<std::unique_ptr<Array>> owned_tensors_;
   std::vector<std::unique_ptr<Literal>> owned_literals_;
   tensorflow::gtl::FlatMap<const HloInstruction*, Array*> cache_;
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
index 76e7e70..68f247b 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
@@ -40,12 +40,14 @@ class IndexedArrayAnalysisTest : public HloVerifiedTestBase {
     IndexedArrayAnalysis indexed_tensor_analysis;
     ParseAndVerifyModule(hlo_text);
 
-    string result = indexed_tensor_analysis.ToString(
+    TF_ASSERT_OK_AND_ASSIGN(
+        IndexedArrayAnalysis::Array* const array_result,
         indexed_tensor_analysis.GetArrayFor(
-            module().entry_computation()->root_instruction()),
-        print_constants);
-    LOG(INFO) << result;
-    ASSERT_EQ(result, root_expression);
+            module().entry_computation()->root_instruction()));
+    string string_result =
+        indexed_tensor_analysis.ToString(array_result, print_constants);
+    LOG(INFO) << string_result;
+    ASSERT_EQ(string_result, root_expression);
   }
 };