From e62f3e5ff68aad1ddef2b581b98a90125e740ddd Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Mon, 28 May 2018 22:16:46 -0700 Subject: [PATCH] Make IndexedArrayAnalysis behave well around StatusOr PiperOrigin-RevId: 198348355 --- .../compiler/xla/service/indexed_array_analysis.cc | 111 +++++++++++---------- .../compiler/xla/service/indexed_array_analysis.h | 36 ++++--- .../xla/service/indexed_array_analysis_test.cc | 12 ++- 3 files changed, 88 insertions(+), 71 deletions(-) diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc index 5d870f9..21af9a6 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc @@ -33,8 +33,6 @@ using tensorflow::gtl::ArraySlice; using tensorflow::str_util::Join; } // namespace -// TODO(sanjoy): Make this pass StatusOr safe. - string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) { switch (root->kind()) { case Array::kUnknown: { @@ -69,18 +67,18 @@ string IndexedArrayAnalysis::ToString(Array* root, bool print_constants) { } } -Analysis::Array* IndexedArrayAnalysis::GetArrayFor( +StatusOr IndexedArrayAnalysis::GetArrayFor( const HloInstruction* instr) { auto it = cache_.find(instr); if (it != cache_.end()) { return it->second; } - TraverseAndPopulateCache(instr); + TF_RETURN_IF_ERROR(TraverseAndPopulateCache(instr)); return FindOrDie(cache_, instr); } -void IndexedArrayAnalysis::TraverseAndPopulateCache( +Status IndexedArrayAnalysis::TraverseAndPopulateCache( const HloInstruction* root) { // Depth first search over the DAG, invoking ComputeArrayFor in post order. // The HLO instructions already in the cache are considered leaves. @@ -116,32 +114,42 @@ void IndexedArrayAnalysis::TraverseAndPopulateCache( case kVisited: stack.pop_back(); - InsertOrDie(&cache_, instr, ComputeArrayFor(instr)); + TF_ASSIGN_OR_RETURN(Array * array, ComputeArrayFor(instr)); + InsertOrDie(&cache_, instr, array); break; } } while (!stack.empty()); + + return Status::OK(); } -Analysis::Array* IndexedArrayAnalysis::ComputeArrayFor( +StatusOr IndexedArrayAnalysis::ComputeArrayFor( const HloInstruction* instr) { Array* computed_array; if (instr->IsElementwise() && instr->operand_count() == 1) { - computed_array = ComputeArrayForElementwiseUnaryOp( - instr, FindOrDie(cache_, instr->operand(0))); + TF_ASSIGN_OR_RETURN(computed_array, + ComputeArrayForElementwiseUnaryOp( + instr, FindOrDie(cache_, instr->operand(0)))); } else if (instr->IsElementwise() && instr->operand_count() == 2) { - computed_array = ComputeArrayForElementwiseBinaryOp( - instr, FindOrDie(cache_, instr->operand(0)), - FindOrDie(cache_, instr->operand(1))); + TF_ASSIGN_OR_RETURN(computed_array, + ComputeArrayForElementwiseBinaryOp( + instr, FindOrDie(cache_, instr->operand(0)), + FindOrDie(cache_, instr->operand(1)))); } else if (instr->opcode() == HloOpcode::kConstant) { - computed_array = ComputeArrayForConstant(instr->literal()); + TF_ASSIGN_OR_RETURN(computed_array, + ComputeArrayForConstant(instr->literal())); } else if (instr->opcode() == HloOpcode::kGather) { - computed_array = ComputeArrayForGather( - instr->shape(), instr->gather_dimension_numbers(), - instr->gather_window_bounds(), FindOrDie(cache_, instr->operand(0)), - FindOrDie(cache_, instr->operand(1))); + TF_ASSIGN_OR_RETURN( + computed_array, + ComputeArrayForGather(instr->shape(), instr->gather_dimension_numbers(), + instr->gather_window_bounds(), + FindOrDie(cache_, instr->operand(0)), + FindOrDie(cache_, instr->operand(1)))); } else if (instr->opcode() == HloOpcode::kReshape) { - computed_array = ComputeArrayForReshape( - instr->shape(), FindOrDie(cache_, instr->operand(0))); + TF_ASSIGN_OR_RETURN( + computed_array, + ComputeArrayForReshape(instr->shape(), + FindOrDie(cache_, instr->operand(0)))); } else { computed_array = nullptr; } @@ -153,12 +161,12 @@ Analysis::Array* IndexedArrayAnalysis::ComputeArrayFor( return computed_array; } -Analysis::Array* IndexedArrayAnalysis::ComputeArrayForConstant( +StatusOr IndexedArrayAnalysis::ComputeArrayForConstant( const Literal& literal) { return Construct(&literal); } -ScalarIndexedArray* IndexedArrayAnalysis::FoldGatherOfGather( +StatusOr IndexedArrayAnalysis::FoldGatherOfGather( ScalarIndexedArray* source, Array* indices, int64 source_dim, tensorflow::gtl::ArraySlice output_dims, Shape shape) { // We want to transform Gather(Gather(A, X), Y) => Gather(A, Gather(X, Y)). @@ -224,7 +232,7 @@ ScalarIndexedArray* IndexedArrayAnalysis::FoldGatherOfGather( std::move(shape)); } -Analysis::Array* IndexedArrayAnalysis::ComputeArrayForGather( +StatusOr IndexedArrayAnalysis::ComputeArrayForGather( const Shape& shape, const GatherDimensionNumbers& dim_numbers, tensorflow::gtl::ArraySlice window_bounds, Array* source, Array* indices) { @@ -397,7 +405,7 @@ int64 FindSourcePositionForPassthroughResultDim(ArraySlice operand_shape, }; // namespace -Analysis::Array* IndexedArrayAnalysis::ComputeArrayForReshape( +StatusOr IndexedArrayAnalysis::ComputeArrayForReshape( const Shape& shape, Array* operand) { auto* scalar_indexed = dynamic_cast(operand); if (!scalar_indexed) { @@ -541,10 +549,12 @@ Analysis::Array* IndexedArrayAnalysis::ComputeArrayForReshape( std::back_inserter(output_dims_for_new_scalar_indexed_node), map_passthrough_operand_dim_to_result_dim); - Array* new_scalar_indexed_source = ComputeArrayForConstant( - *TakeOwnership(scalar_indexed->literal() - .Reshape(new_scalar_indexed_source_shape) - .ValueOrDie())); + TF_ASSIGN_OR_RETURN(const Literal* new_scalar_indexed_source_literal, + TakeOwnership(scalar_indexed->literal().Reshape( + new_scalar_indexed_source_shape))); + TF_ASSIGN_OR_RETURN( + Array * new_scalar_indexed_source, + ComputeArrayForConstant(*new_scalar_indexed_source_literal)); return ConstructScalarIndexedArray( new_scalar_indexed_source, scalar_indexed->indices(), @@ -552,7 +562,8 @@ Analysis::Array* IndexedArrayAnalysis::ComputeArrayForReshape( output_dims_for_new_scalar_indexed_node, shape); } -Analysis::Array* IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp( +StatusOr +IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp( const HloInstruction* instr, Array* lhs, Array* rhs) { // Try to fold BinaryOp(Broadcast(Const0), ScalarIndexed(Const1, Indices)) // => ScalarIndexed(BinaryOp(Broadcast'(Const0), Const1), Indices) @@ -642,28 +653,25 @@ Analysis::Array* IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp( // inner_broadcast_result is the Broadcast'(Const0) bit in // BinaryOp(Broadcast'(Const0), Const1) - std::unique_ptr inner_broadcast_result = - broadcast_const_operand->literal() - .Broadcast(scalar_indexed_const->source()->shape(), - new_inner_broadcast_dims) - .ConsumeValueOrDie(); + TF_ASSIGN_OR_RETURN( + std::unique_ptr inner_broadcast_result, + broadcast_const_operand->literal().Broadcast( + scalar_indexed_const->source()->shape(), new_inner_broadcast_dims)); // literal_for_new_source is BinaryOp(Broadcast'(Const0), Const1) const Literal* literal_for_new_source; if (lhs_is_indexed) { - literal_for_new_source = - TakeOwnership(HloEvaluator{} - .EvaluateElementwiseBinaryOp( - instr->opcode(), scalar_indexed_const->literal(), - *inner_broadcast_result) - .ConsumeValueOrDie()); + TF_ASSIGN_OR_RETURN( + literal_for_new_source, + TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp( + instr->opcode(), scalar_indexed_const->literal(), + *inner_broadcast_result))); } else { - literal_for_new_source = - TakeOwnership(HloEvaluator{} - .EvaluateElementwiseBinaryOp( - instr->opcode(), *inner_broadcast_result, - scalar_indexed_const->literal()) - .ConsumeValueOrDie()); + TF_ASSIGN_OR_RETURN( + literal_for_new_source, + TakeOwnership(HloEvaluator{}.EvaluateElementwiseBinaryOp( + instr->opcode(), *inner_broadcast_result, + scalar_indexed_const->literal()))); } ConstantArray* new_source = Construct(literal_for_new_source); @@ -675,7 +683,8 @@ Analysis::Array* IndexedArrayAnalysis::ComputeArrayForElementwiseBinaryOp( scalar_indexed_const->shape()); } -Analysis::Array* IndexedArrayAnalysis::ComputeArrayForElementwiseUnaryOp( +StatusOr +IndexedArrayAnalysis::ComputeArrayForElementwiseUnaryOp( const HloInstruction* instr, Array* operand) { auto* scalar_indexed_const = dynamic_cast(operand); @@ -686,11 +695,9 @@ Analysis::Array* IndexedArrayAnalysis::ComputeArrayForElementwiseUnaryOp( // Fold UnaryOp(ScalarIndexed(Const, Indices)) // => ScalarIndexed(UnaryOp(Const), Indices) - Literal* literal_for_new_source = - TakeOwnership(HloEvaluator{} - .EvaluateElementwiseUnaryOp( - instr->opcode(), scalar_indexed_const->literal()) - .ConsumeValueOrDie()); + TF_ASSIGN_OR_RETURN(Literal * literal_for_new_source, + TakeOwnership(HloEvaluator{}.EvaluateElementwiseUnaryOp( + instr->opcode(), scalar_indexed_const->literal()))); ConstantArray* new_source = Construct(literal_for_new_source); return Construct( new_source, scalar_indexed_const->indices(), @@ -712,7 +719,7 @@ StatusOr IndexedArrayAnalysisPrinterPass::Run(HloModule* module) { IndexedArrayAnalysis analysis; for (auto* computation : module->MakeNonfusionComputations()) { for (auto* instr : computation->instructions()) { - auto* t = analysis.GetArrayFor(instr); + TF_ASSIGN_OR_RETURN(Analysis::Array * t, analysis.GetArrayFor(instr)); if (!dynamic_cast(t) && !dynamic_cast(t)) { VLOG(2) << instr->ToString() << " -> " << analysis.ToString(t); } diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h index 8c1f616..561832a 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis.h +++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h @@ -220,7 +220,7 @@ class IndexedArrayAnalysis { // NB! By inspecting the implementation, you may be able to infer a stronger // caching guarantee than what is mentioned above. Nevertheless, what is // stated above is the contract. - Array* GetArrayFor(const HloInstruction* instr); + StatusOr GetArrayFor(const HloInstruction* instr); // Pretty-prints the expression rooted at `root`. string ToString(Array* root, bool print_constants = false); @@ -228,18 +228,18 @@ class IndexedArrayAnalysis { private: // Helper function that ensures that every HLO instruction that is // transitively used by `root` has an entry in `cache_`. - void TraverseAndPopulateCache(const HloInstruction* root); + Status TraverseAndPopulateCache(const HloInstruction* root); // Creates an Array instance for `instr` under the assumption that all // operations of `instr` are present in `cache_`. - Array* ComputeArrayFor(const HloInstruction* instr); + StatusOr ComputeArrayFor(const HloInstruction* instr); - Array* ComputeArrayForConstant(const Literal& literal); + StatusOr ComputeArrayForConstant(const Literal& literal); - Array* ComputeArrayForGather(const Shape& shape, - const GatherDimensionNumbers& dim_numbers, - tensorflow::gtl::ArraySlice window_bounds, - Array* source, Array* indices); + StatusOr ComputeArrayForGather( + const Shape& shape, const GatherDimensionNumbers& dim_numbers, + tensorflow::gtl::ArraySlice window_bounds, Array* source, + Array* indices); // This tries to fold a ScalarIndexedArray which has another // ScalarIndexedArray as a source into a ScalarIndexedArray that instead has a @@ -262,16 +262,16 @@ class IndexedArrayAnalysis { // // I2 = [I0[i] for i in I1] // G1 = [Arr[i] for i in I2] - ScalarIndexedArray* FoldGatherOfGather( + StatusOr FoldGatherOfGather( ScalarIndexedArray* source, Array* indices, int64 source_dim, tensorflow::gtl::ArraySlice output_dims, Shape shape); - Array* ComputeArrayForReshape(const Shape& shape, Array* operand); + StatusOr ComputeArrayForReshape(const Shape& shape, Array* operand); - Array* ComputeArrayForElementwiseBinaryOp(const HloInstruction* instr, - Array* lhs, Array* rhs); - Array* ComputeArrayForElementwiseUnaryOp(const HloInstruction* instr, - Array* operand); + StatusOr ComputeArrayForElementwiseBinaryOp( + const HloInstruction* instr, Array* lhs, Array* rhs); + StatusOr ComputeArrayForElementwiseUnaryOp( + const HloInstruction* instr, Array* operand); template T* Construct(Args&&... args) { @@ -299,6 +299,14 @@ class IndexedArrayAnalysis { return owned_literals_.back().get(); } + StatusOr TakeOwnership( + StatusOr> literal_or_error) { + TF_ASSIGN_OR_RETURN(std::unique_ptr literal, + std::move(literal_or_error)); + owned_literals_.push_back(std::move(literal)); + return owned_literals_.back().get(); + } + std::vector> owned_tensors_; std::vector> owned_literals_; tensorflow::gtl::FlatMap cache_; diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc index 76e7e70..68f247b 100644 --- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc +++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc @@ -40,12 +40,14 @@ class IndexedArrayAnalysisTest : public HloVerifiedTestBase { IndexedArrayAnalysis indexed_tensor_analysis; ParseAndVerifyModule(hlo_text); - string result = indexed_tensor_analysis.ToString( + TF_ASSERT_OK_AND_ASSIGN( + IndexedArrayAnalysis::Array* const array_result, indexed_tensor_analysis.GetArrayFor( - module().entry_computation()->root_instruction()), - print_constants); - LOG(INFO) << result; - ASSERT_EQ(result, root_expression); + module().entry_computation()->root_instruction())); + string string_result = + indexed_tensor_analysis.ToString(array_result, print_constants); + LOG(INFO) << string_result; + ASSERT_EQ(string_result, root_expression); } }; -- 2.7.4