From 2248a3488c53f8b858e2a0b8be93d62c3056df36 Mon Sep 17 00:00:00 2001
From: Justin Lebar
Date: Thu, 5 Apr 2018 18:23:32 -0700
Subject: [PATCH] [XLA] Don't call Literal::Get in HloEvaluator's convolution loop.

This speeds up the implementation of conv because Literal::Get calls
Literal::Piece::data, which is relatively slow. Instead, we call
Literal::data() once and cache the result.

Before: ConvolutionTest/0.StridedFilter (59094 ms)
After: ConvolutionTest/0.StridedFilter (41812 ms)

Speedup: 59/42 = 1.4x

PiperOrigin-RevId: 191830741
---
 tensorflow/compiler/xla/service/hlo_evaluator.cc | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 9d7251b..4bec953 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -1003,6 +1003,9 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
     DimensionVector rhs_index(rhs_rank);
     DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size());
 
+    auto lhs_literal_data = lhs_literal.data();
+    auto rhs_literal_data = rhs_literal.data();
+
     auto func = [&](ArraySlice out_index) {
       ElementwiseT result_val = static_cast(0);
 
@@ -1062,9 +1065,13 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
                   : rhs_spatial_index[ki];
           }
 
-          result_val +=
-              static_cast(lhs_literal.Get(lhs_index)) *
-              static_cast(rhs_literal.Get(rhs_index));
+          auto lhs_elem = static_cast(
+              lhs_literal_data[IndexUtil::MultidimensionalIndexToLinearIndex(
+                  lhs_shape, lhs_index)]);
+          auto rhs_elem = static_cast(
+              rhs_literal_data[IndexUtil::MultidimensionalIndexToLinearIndex(
+                  rhs_shape, rhs_index)]);
+          result_val += lhs_elem * rhs_elem;
         }
       cnt : {}
       } while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index));
-- 
2.7.4