} while (IndexUtil::BumpIndices(window_shape, &window_index));
}
+// Creates a vector of multipliers which can be used to create a linear index
+// into shape.
+//
+// Given the multidimensional index {i1, ..., iN} and
+// M = MakeDimMultipliers(shape), the corresponding linear index LI is simply
+//
+// LI = i1 * M[1] + i2 * M[2] + ... + iN * M[N].
+//
+// This lets you calculate LI given the multidimensional indices in any order.
+DimensionVector MakeDimMultipliers(const Shape& shape) {
+ DimensionVector v(ShapeUtil::Rank(shape));
+ int64 scale = 1;
+ for (auto dim : LayoutUtil::MinorToMajor(shape)) {
+ v[dim] = scale;
+ scale *= shape.dimensions(dim);
+ }
+ return v;
+}
+
} // namespace
template <typename ReturnT, typename ElementwiseT>
const Shape& window_shape =
ShapeUtil::MakeShape(rhs_shape.element_type(), window_dimension_sizes);
- DimensionVector lhs_index(lhs_rank);
- DimensionVector rhs_index(rhs_rank);
+ DimensionVector lhs_dim_multipliers = MakeDimMultipliers(lhs_shape);
+ DimensionVector rhs_dim_multipliers = MakeDimMultipliers(rhs_shape);
+
DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size());
auto lhs_literal_data = lhs_literal.data<ReturnT>();
auto func = [&](ArraySlice<int64> out_index) {
ElementwiseT result_val = static_cast<ElementwiseT>(0);
-
- std::fill(lhs_index.begin(), lhs_index.end(), 0);
- std::fill(rhs_index.begin(), rhs_index.end(), 0);
std::fill(rhs_spatial_index.begin(), rhs_spatial_index.end(), 0);
- lhs_index[input_batch_dim] = out_index[output_batch_dim];
- rhs_index[kernel_output_z_dim] = out_index[output_z_dim];
-
// Convolve input feature with kernel.
do {
for (int64 iz = 0; iz < z_size; ++iz) {
- lhs_index[input_z_dim] = iz;
- rhs_index[kernel_input_z_dim] = iz;
+ int64 lhs_linear_index = 0;
+ lhs_linear_index += out_index[output_batch_dim] *
+ lhs_dim_multipliers[input_batch_dim];
+ lhs_linear_index += iz * lhs_dim_multipliers[input_z_dim];
+
+ int64 rhs_linear_index = 0;
+ rhs_linear_index += out_index[output_z_dim] *
+ rhs_dim_multipliers[kernel_output_z_dim];
+ rhs_linear_index += iz * rhs_dim_multipliers[kernel_input_z_dim];
// Find corresponding spatial dimension index for input (lhs).
for (int64 ki = 0; ki < rhs_spatial_index.size(); ++ki) {
// Calculate the actual lhs (input) index after dilation. As an
// optimization, skip this integer divide if there's no dilation.
+ int64 lhs_spatial_index;
if (window_dim.base_dilation() > 1) {
- lhs_index[input_spatial_dim] =
- undilated_index / window_dim.base_dilation();
+ lhs_spatial_index = undilated_index / window_dim.base_dilation();
} else {
- lhs_index[input_spatial_dim] = undilated_index;
+ lhs_spatial_index = undilated_index;
}
+ lhs_linear_index +=
+ lhs_spatial_index * lhs_dim_multipliers[input_spatial_dim];
- // Skip if input index is not in bound.
- if (!(lhs_index[input_spatial_dim] >= 0 &&
- lhs_index[input_spatial_dim] <
+ // Skip if input index is not in bounds.
+ if (!(lhs_spatial_index >= 0 &&
+ lhs_spatial_index <
lhs_shape.dimensions(input_spatial_dim))) {
goto cnt;
}
- rhs_index[dnums.kernel_spatial_dimensions(ki)] =
- window_dim.window_reversal()
- ? ((window_dim.size() - 1) - rhs_spatial_index[ki])
- : rhs_spatial_index[ki];
+ rhs_linear_index +=
+ (window_dim.window_reversal()
+ ? ((window_dim.size() - 1) - rhs_spatial_index[ki])
+ : rhs_spatial_index[ki]) *
+ rhs_dim_multipliers[dnums.kernel_spatial_dimensions(ki)];
}
- auto lhs_elem = static_cast<ElementwiseT>(
- lhs_literal_data[IndexUtil::MultidimensionalIndexToLinearIndex(
- lhs_shape, lhs_index)]);
- auto rhs_elem = static_cast<ElementwiseT>(
- rhs_literal_data[IndexUtil::MultidimensionalIndexToLinearIndex(
- rhs_shape, rhs_index)]);
- result_val += lhs_elem * rhs_elem;
+ result_val +=
+ static_cast<ElementwiseT>(lhs_literal_data[lhs_linear_index]) *
+ static_cast<ElementwiseT>(rhs_literal_data[rhs_linear_index]);
}
cnt : {}
} while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index));