Fixes reduction launch config (#64304)
author	Xiang Gao <qasdfgtyuiop@gmail.com>
Wed, 1 Sep 2021 17:17:52 +0000 (10:17 -0700)
committer	Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 1 Sep 2021 17:30:40 +0000 (10:30 -0700)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/48573
See also https://github.com/pytorch/pytorch/pull/64194

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64304

Reviewed By: janeyx99

Differential Revision: D30689600

Pulled By: ngimel

fbshipit-source-id: bf2103ca177fd3b6e27bc0324b81925234483a29

aten/src/ATen/native/cuda/LinearAlgebra.cu
aten/src/ATen/native/cuda/Normalization.cu
aten/src/ATen/native/cuda/Reduce.cuh

aten/src/ATen/native/cuda/LinearAlgebra.cu
index b7ecf38..b4936c0 100644
@@ -5,7 +5,6 @@
 #include <ATen/native/LinearAlgebra.h>
 #include <ATen/native/DispatchStub.h>
 #include <ATen/native/cuda/Loops.cuh>
-#include <ATen/native/cuda/Reduce.cuh>
 #include <ATen/native/SharedReduceOps.h>
 #include <ATen/native/ReduceOps.h>
 
aten/src/ATen/native/cuda/Normalization.cu
index 1d4d1cc..44e27a9 100644
@@ -2,7 +2,6 @@
 #include <ATen/native/ReduceOps.h>
 #include <ATen/native/Resize.h>
 #include <ATen/native/cuda/Loops.cuh>
-#include <ATen/native/cuda/Reduce.cuh>
 #include <ATen/native/cuda/Resize.cuh>
 #include <ATen/native/cuda/Normalization.cuh>
 #include <c10/cuda/CUDAMathCompat.h>
aten/src/ATen/native/cuda/Reduce.cuh
index b460045..3be7100 100644
@@ -989,14 +989,14 @@ inline void gpu_reduce_kernel(TensorIterator& iter, const ops_t& ops, ident_t id
       // Map block.x to the fastest reducing dimension. It implies:
       //   1. block_x_reduce is required.
       //   2. block.y now max out to num_outputs.
-      dim0 = iter.shape()[0];
+      dim0 = inputs_per_output;
       dim1 = num_outputs;
       fastest_moving_stride = iter.strides(/*arg=*/input_index)[0];
     } else {
       // Map block.x to the fastest non reducing dimension. It implies:
       //   1. block_x_reduce is turned off.
       //   2. block.y now max out to inputs_per_output.
-      dim0 = iter.shape()[iter.num_reduce_dims()];
+      dim0 = num_outputs;
       dim1 = inputs_per_output;
       fastest_moving_stride = iter.strides(/*arg=*/input_index)[iter.num_reduce_dims()];
     }
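
For illustration, a minimal standalone sketch of the block-dimension choice this hunk adjusts. The helper pick_block_dims and the fixed 512-thread budget are assumptions for the sketch, not the actual ReduceConfig code in Reduce.cuh; it only shows dim0/dim1 being taken directly from inputs_per_output and num_outputs (as in the patched lines) rather than from individual entries of the iterator's shape.

    // Simplified sketch only: pick_block_dims is a hypothetical helper,
    // not the real ATen ReduceConfig API. It illustrates deriving dim0/dim1
    // from inputs_per_output / num_outputs instead of indexing iter.shape().
    #include <algorithm>
    #include <cstdint>
    #include <utility>

    constexpr int kMaxThreadsPerBlock = 512;  // assumed block budget for the sketch

    std::pair<int, int> pick_block_dims(int64_t inputs_per_output,
                                        int64_t num_outputs,
                                        bool reduce_on_fastest_dim) {
      int64_t dim0, dim1;
      if (reduce_on_fastest_dim) {
        // block.x threads cooperate on a single reduction (block_x_reduce);
        // block.y spreads over the independent outputs.
        dim0 = inputs_per_output;
        dim1 = num_outputs;
      } else {
        // block.x walks over outputs; block.y splits each reduction.
        dim0 = num_outputs;
        dim1 = inputs_per_output;
      }
      int block_x = static_cast<int>(std::min<int64_t>(dim0, kMaxThreadsPerBlock));
      int block_y = static_cast<int>(
          std::min<int64_t>(dim1, kMaxThreadsPerBlock / std::max(block_x, 1)));
      return {block_x, block_y};
    }

With dim0/dim1 expressed this way, the launch configuration tracks the reduction geometry directly, independent of how the iterator's dimensions happen to be ordered.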