for (int index = get_global_id(0); index < channels * spatial_dim;
index += get_global_size(0)) {
int s = index % spatial_dim;
- out[n * channels * spatial_dim + index] = out_tmp[index] / scale_tmp[s];
+ Dtype v = out_tmp[index] / scale_tmp[s];  // quotient that the removed line stored straight into out
+#ifdef LOG_SOFTMAX
+ v = log(v);  // LOG_SOFTMAX builds store the log of the quotient instead. NOTE(review): a zero quotient yields -inf; confirm the inputs cannot underflow to 0
+#endif  // LOG_SOFTMAX
+ out[n * channels * spatial_dim + index] = v;  // same destination element as the removed line
}
}
for (int index = get_global_id(0); index < channels * spatial_dim;
index += get_global_size(0)) {
int s = index % spatial_dim;
- out[n * channels * spatial_dim + index] /= scale[n * spatial_dim + s];
+ Dtype v = out[n * channels * spatial_dim + index] / scale[n * spatial_dim + s];  // same quotient the removed in-place "/=" computed
+#ifdef LOG_SOFTMAX
+ v = log(v);  // LOG_SOFTMAX builds store the log of the quotient instead. NOTE(review): a zero quotient yields -inf; confirm the inputs cannot underflow to 0
+#endif  // LOG_SOFTMAX
+ out[n * channels * spatial_dim + index] = v;  // written back to the element that was read above
}
}