[XLA] Don't call MultidimensionalIndexToLinearIndex in HloEvaluator's convolution...
authorJustin Lebar <jlebar@google.com>
Fri, 6 Apr 2018 02:30:10 +0000 (19:30 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 6 Apr 2018 02:32:55 +0000 (19:32 -0700)
Before: ConvolutionTest/0.StridedFilter (41812 ms)
After:  ConvolutionTest/0.StridedFilter (28054 ms)

Speedup: 42 / 28 = 1.5x
PiperOrigin-RevId: 191835735

tensorflow/compiler/xla/service/hlo_evaluator.cc

index 4bec953..53ad890 100644 (file)
@@ -202,6 +202,25 @@ void IterateThroughWindow(
   } while (IndexUtil::BumpIndices(window_shape, &window_index));
 }
 
+// Creates a vector of multipliers which can be used to create a linear index
+// into shape.
+//
+// Given the multidimensional index {i1, ..., iN} and
+// M = MakeDimMultipliers(shape), the corresponding linear index LI is simply
+//
+//   LI = i1 * M[1] + i2 * M[2] + ... + iN * M[N].
+//
+// This lets you calculate LI given the multidimensional indices in any order.
+DimensionVector MakeDimMultipliers(const Shape& shape) {
+  DimensionVector v(ShapeUtil::Rank(shape));
+  int64 scale = 1;
+  for (auto dim : LayoutUtil::MinorToMajor(shape)) {
+    v[dim] = scale;
+    scale *= shape.dimensions(dim);
+  }
+  return v;
+}
+
 }  // namespace
 
 template <typename ReturnT, typename ElementwiseT>
@@ -999,8 +1018,9 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
     const Shape& window_shape =
         ShapeUtil::MakeShape(rhs_shape.element_type(), window_dimension_sizes);
 
-    DimensionVector lhs_index(lhs_rank);
-    DimensionVector rhs_index(rhs_rank);
+    DimensionVector lhs_dim_multipliers = MakeDimMultipliers(lhs_shape);
+    DimensionVector rhs_dim_multipliers = MakeDimMultipliers(rhs_shape);
+
     DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size());
 
     auto lhs_literal_data = lhs_literal.data<ReturnT>();
@@ -1008,19 +1028,20 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
 
     auto func = [&](ArraySlice<int64> out_index) {
       ElementwiseT result_val = static_cast<ElementwiseT>(0);
-
-      std::fill(lhs_index.begin(), lhs_index.end(), 0);
-      std::fill(rhs_index.begin(), rhs_index.end(), 0);
       std::fill(rhs_spatial_index.begin(), rhs_spatial_index.end(), 0);
 
-      lhs_index[input_batch_dim] = out_index[output_batch_dim];
-      rhs_index[kernel_output_z_dim] = out_index[output_z_dim];
-
       // Convolve input feature with kernel.
       do {
         for (int64 iz = 0; iz < z_size; ++iz) {
-          lhs_index[input_z_dim] = iz;
-          rhs_index[kernel_input_z_dim] = iz;
+          int64 lhs_linear_index = 0;
+          lhs_linear_index += out_index[output_batch_dim] *
+                              lhs_dim_multipliers[input_batch_dim];
+          lhs_linear_index += iz * lhs_dim_multipliers[input_z_dim];
+
+          int64 rhs_linear_index = 0;
+          rhs_linear_index += out_index[output_z_dim] *
+                              rhs_dim_multipliers[kernel_output_z_dim];
+          rhs_linear_index += iz * rhs_dim_multipliers[kernel_input_z_dim];
 
           // Find corresponding spatial dimension index for input (lhs).
           for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) {
@@ -1045,33 +1066,32 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault {
 
             // Calculate the actual lhs (input) index after dilation.  As an
             // optimization, skip this integer divide if there's no dilation.
+            int64 lhs_spatial_index;
             if (window_dim.base_dilation() > 1) {
-              lhs_index[input_spatial_dim] =
-                  undilated_index / window_dim.base_dilation();
+              lhs_spatial_index = undilated_index / window_dim.base_dilation();
             } else {
-              lhs_index[input_spatial_dim] = undilated_index;
+              lhs_spatial_index = undilated_index;
             }
+            lhs_linear_index +=
+                lhs_spatial_index * lhs_dim_multipliers[input_spatial_dim];
 
-            // Skip if input index is not in bound.
-            if (!(lhs_index[input_spatial_dim] >= 0 &&
-                  lhs_index[input_spatial_dim] <
+            // Skip if input index is not in bounds.
+            if (!(lhs_spatial_index >= 0 &&
+                  lhs_spatial_index <
                       lhs_shape.dimensions(input_spatial_dim))) {
               goto cnt;
             }
 
-            rhs_index[dnums.kernel_spatial_dimensions(ki)] =
-                window_dim.window_reversal()
-                    ? ((window_dim.size() - 1) - rhs_spatial_index[ki])
-                    : rhs_spatial_index[ki];
+            rhs_linear_index +=
+                (window_dim.window_reversal()
+                     ? ((window_dim.size() - 1) - rhs_spatial_index[ki])
+                     : rhs_spatial_index[ki]) *
+                rhs_dim_multipliers[dnums.kernel_spatial_dimensions(ki)];
           }
 
-          auto lhs_elem = static_cast<ElementwiseT>(
-              lhs_literal_data[IndexUtil::MultidimensionalIndexToLinearIndex(
-                  lhs_shape, lhs_index)]);
-          auto rhs_elem = static_cast<ElementwiseT>(
-              rhs_literal_data[IndexUtil::MultidimensionalIndexToLinearIndex(
-                  rhs_shape, rhs_index)]);
-          result_val += lhs_elem * rhs_elem;
+          result_val +=
+              static_cast<ElementwiseT>(lhs_literal_data[lhs_linear_index]) *
+              static_cast<ElementwiseT>(rhs_literal_data[rhs_linear_index]);
         }
       cnt : {}
       } while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index));