[XLA] Don't call Literal::Get in HloEvaluator's convolution loop.
author: Justin Lebar <jlebar@google.com>
Fri, 6 Apr 2018 01:23:32 +0000 (18:23 -0700)
committer: TensorFlower Gardener <gardener@tensorflow.org>
Fri, 6 Apr 2018 01:27:08 +0000 (18:27 -0700)
This speeds up the implementation of conv because Literal::Get calls
Literal::Piece::data, which is relatively slow.

Instead, we call Literal::data<T>() once per operand (lhs and rhs), before
the loop, and cache the result.

Before: ConvolutionTest/0.StridedFilter (59094 ms)
After:  ConvolutionTest/0.StridedFilter (41812 ms)

Speedup: 59/42 = 1.4x
PiperOrigin-RevId: 191830741

tensorflow/compiler/xla/service/hlo_evaluator.cc

index 9d7251b..4bec953 100644 (file)
@@ -1003,6 +1003,9 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
     DimensionVector rhs_index(rhs_rank);
     DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size());
 
+    auto lhs_literal_data = lhs_literal.data<ReturnT>();
+    auto rhs_literal_data = rhs_literal.data<ReturnT>();
+
     auto func = [&](ArraySlice<int64> out_index) {
       ElementwiseT result_val = static_cast<ElementwiseT>(0);
 
@@ -1062,9 +1065,13 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
                     : rhs_spatial_index[ki];
           }
 
-          result_val +=
-              static_cast<ElementwiseT>(lhs_literal.Get<ReturnT>(lhs_index)) *
-              static_cast<ElementwiseT>(rhs_literal.Get<ReturnT>(rhs_index));
+          auto lhs_elem = static_cast<ElementwiseT>(
+              lhs_literal_data[IndexUtil::MultidimensionalIndexToLinearIndex(
+                  lhs_shape, lhs_index)]);
+          auto rhs_elem = static_cast<ElementwiseT>(
+              rhs_literal_data[IndexUtil::MultidimensionalIndexToLinearIndex(
+                  rhs_shape, rhs_index)]);
+          result_val += lhs_elem * rhs_elem;
         }
       cnt : {}
       } while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index));