bool is_last_block_done = mark_block_finished();
if (is_last_block_done) {
- value = arg_t {};
+ value = ident;
if (config.should_warp_reduce()) {
index_t input_offset = threadIdx.x + threadIdx.y * blockDim.x;
index_t step = blockDim.x * blockDim.y;
x = torch.ones(65536, device='cuda', dtype=torch.float16)
self.assertEqual(x.mean(dtype=torch.float32), 1)
def test_prod_large(self):
    """Exercise the global-reduction kernel path (should_global_reduce = true)
    for a reduction whose identity element is non-zero (prod's identity is 1)."""
    # A 240000-element tensor is large enough to force the multi-block
    # (global) reduction; the product of all-ones must still be exactly 1.
    ones = torch.ones(240000, device='cuda', dtype=torch.float32)
    self.assertEqual(ones.prod(), 1)
@staticmethod
def _select_broadcastable_dims(dims_full=None):
    """Delegate to the shared `_TestTorchMixin` helper for picking
    broadcast-compatible dimension tuples."""
    result = _TestTorchMixin._select_broadcastable_dims(dims_full)
    return result