From 1cb3552c019d351bf740457e7d14da54324c5921 Mon Sep 17 00:00:00 2001 From: Nick Desaulniers Date: Wed, 16 May 2018 10:03:12 -0700 Subject: [PATCH] [TF:XLA:INTERPRETER] speed up select and scatter by avoiding memory allocation in loops HandleSelectAndScatter() has 2 IterateThroughWindow() blocks. Before, we spent (in percent total program time): 11.98% Literal::CreateR0() = 10.82% (block1) + 1.16% (block2) 4.91% Literal::~Literal() = 4.44% (block1) + 0.51% (block2) 1.52% operator delete = 1.38% (block1) + 0.14% (block2) ===== 18.41% total After: 1.99% Literal::~Literal() = 1.83% (block1) + 0.16% (block2) 0.68% operator delete = 0.61% (block1) + 0.07% (block2) ===== 2.67% total PiperOrigin-RevId: 196844177 --- .../xla/service/hlo_evaluator_typed_visitor.h | 37 ++++++++++++---------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 0e4ef08..b8744cc 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -1556,9 +1556,14 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { int64 rank = ShapeUtil::Rank(operand_literal.shape()); HloEvaluator embedded_evaluator(parent_->max_loop_iterations_); - DimensionVector source_index(rank); - - std::fill(source_index.begin(), source_index.end(), 0); + DimensionVector source_index(rank, 0); + + // Used in the dual IterateThroughWindow lambdas below. Hoisted to avoid + // dynamic memory allocations. + auto curr_val_literal = Literal::CreateR0(ReturnT()); + auto selected_val_literal = Literal::CreateR0(ReturnT()); + auto source_literal_scatter = Literal::CreateR0(ReturnT()); + auto scattered_literal = Literal::CreateR0(ReturnT()); do { // For each element in `source`, we place a window in `operand`. For each // window placement, we iterate inside the window twice: @@ -1582,14 +1587,13 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { selected_val = curr_val; selected_index = operand_index; } - const auto curr_val_literal = Literal::CreateR0(curr_val); - const auto selected_val_literal = - Literal::CreateR0(*selected_val); - - const std::vector args = { - selected_val_literal.get(), curr_val_literal.get()}; + curr_val_literal->Set({}, curr_val); + selected_val_literal->Set({}, *selected_val); std::unique_ptr computed_result = - embedded_evaluator.Evaluate(*select, args) + embedded_evaluator + .Evaluate( + *select, + {selected_val_literal.get(), curr_val_literal.get()}) .ConsumeValueOrDie(); bool selected = !computed_result->Get({}); if (selected) { @@ -1606,14 +1610,13 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { selected_index->begin())) { auto source = source_literal.Get(source_index); auto scattered = result->Get(operand_index); - const auto source_literal = Literal::CreateR0(source); - const auto scattered_literal = - Literal::CreateR0(scattered); - - const std::vector args = { - source_literal.get(), scattered_literal.get()}; + source_literal_scatter->Set({}, source); + scattered_literal->Set({}, scattered); std::unique_ptr computed_result = - embedded_evaluator.Evaluate(*scatter, args) + embedded_evaluator + .Evaluate(*scatter, + {source_literal_scatter.get(), + scattered_literal.get()}) .ConsumeValueOrDie(); result->Set(operand_index, computed_result->Get({})); // Clear visit states so that the we can use the evaluator again -- 2.7.4