output_shape[axis] = segment_count;
auto output = at::empty(output_shape, data.options());
- AT_DISPATCH_INDEX_TYPES(lengths.type(), "_segment_reduce_cpu_kernel1", [&]() {
+ AT_DISPATCH_INDEX_TYPES(lengths.scalar_type(), "_segment_reduce_cpu_kernel1", [&]() {
const auto* lengths_data = lengths.data_ptr<index_t>();
_segment_reduce_cpu_kernel1(
reduction, data, lengths_data, axis, initial, output, segment_count);
auto grad_input = at::zeros({data_contig.sizes()}, grad_contig.options());
AT_DISPATCH_INDEX_TYPES(
- lengths_contig.type(), "_segment_reduce_cpu_backward_kernel1", [&]() {
+ lengths_contig.scalar_type(), "_segment_reduce_cpu_backward_kernel1", [&] {
const auto* lengths_data = lengths_contig.data_ptr<index_t>();
_segment_reduce_cpu_backward_kernel1(
grad_contig,