DimensionVector rhs_index(rhs_rank);
DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size());
+ auto lhs_literal_data = lhs_literal.data<ReturnT>();
+ auto rhs_literal_data = rhs_literal.data<ReturnT>();
+
auto func = [&](ArraySlice<int64> out_index) {
ElementwiseT result_val = static_cast<ElementwiseT>(0);
: rhs_spatial_index[ki];
}
- result_val +=
- static_cast<ElementwiseT>(lhs_literal.Get<ReturnT>(lhs_index)) *
- static_cast<ElementwiseT>(rhs_literal.Get<ReturnT>(rhs_index));
+ auto lhs_elem = static_cast<ElementwiseT>(
+ lhs_literal_data[IndexUtil::MultidimensionalIndexToLinearIndex(
+ lhs_shape, lhs_index)]);
+ auto rhs_elem = static_cast<ElementwiseT>(
+ rhs_literal_data[IndexUtil::MultidimensionalIndexToLinearIndex(
+ rhs_shape, rhs_index)]);
+ result_val += lhs_elem * rhs_elem;
}
cnt : {}
} while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index));