} \
}()
-#define AT_DISPATCH_ALL_TYPES_AND_HALF(TYPE, NAME, ...) \
- [&] { \
- const at::Type& the_type = TYPE; \
- switch (the_type.scalarType()) { \
- AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__) \
- default: \
- AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \
- } \
- }()
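+// Maps a compile-time at::ScalarType value to the corresponding C++ type;
+// specialized below for at::ScalarType::Half.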
+template <at::ScalarType N>
+struct MyTemplate;
-#define AT_DISPATCH_COMPLEX_TYPES(TYPE, NAME, ...) \
- [&] { \
- const at::Type& the_type = TYPE; \
- switch (the_type.scalarType()) { \
- AT_PRIVATE_CASE_TYPE( \
- at::ScalarType::ComplexFloat, std::complex<float>, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE( \
- at::ScalarType::ComplexDouble, std::complex<double>, __VA_ARGS__) \
- default: \
- AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \
- } \
- }()
+template<>
+struct MyTemplate<at::ScalarType::Half> {
+ using type = at::Half;
+};
-#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX(TYPE, NAME, ...) \
- [&] { \
- const at::Type& the_type = TYPE; \
- switch (the_type.scalarType()) { \
- AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE( \
- at::ScalarType::ComplexFloat, std::complex<float>, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE( \
- at::ScalarType::ComplexDouble, std::complex<double>, __VA_ARGS__) \
- default: \
- AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \
- } \
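+// Like AT_DISPATCH_ALL_TYPES, but additionally dispatches on the extra
+// ScalarType passed as the first argument, e.g.
+//   AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.type(), "_copy__cpu", ...)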
+#define AT_DISPATCH_ALL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \
+ [&] { \
+ const at::Type& the_type = TYPE; \
+ switch (the_type.scalarType()) { \
+ AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE(SCALARTYPE, MyTemplate<SCALARTYPE>::type, __VA_ARGS__) \
+ default: \
+ AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \
+ } \
}()
-#define AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX(TYPE, NAME, ...) \
- [&] { \
- const at::Type& the_type = TYPE; \
- switch (the_type.scalarType()) { \
- AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE( \
- at::ScalarType::ComplexFloat, std::complex<float>, __VA_ARGS__) \
- AT_PRIVATE_CASE_TYPE( \
- at::ScalarType::ComplexDouble, std::complex<double>, __VA_ARGS__) \
- default: \
- AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \
- } \
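+// Like AT_DISPATCH_ALL_TYPES_AND_COMPLEX, but additionally dispatches on the
+// extra ScalarType passed as the first argument.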
+#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, TYPE, NAME, ...) \
+ [&] { \
+ const at::Type& the_type = TYPE; \
+ switch (the_type.scalarType()) { \
+ AT_PRIVATE_CASE_TYPE(at::ScalarType::Byte, uint8_t, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE(at::ScalarType::Char, int8_t, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE(at::ScalarType::Int, int32_t, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE(at::ScalarType::Long, int64_t, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE(at::ScalarType::Short, int16_t, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE(SCALARTYPE, MyTemplate<SCALARTYPE>::type, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE( \
+ at::ScalarType::ComplexFloat, std::complex<float>, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE( \
+ at::ScalarType::ComplexDouble, std::complex<double>, __VA_ARGS__) \
+ default: \
+ AT_ERROR(#NAME, " not implemented for '", the_type.toString(), "'"); \
+ } \
}()
template <typename self_T>
void _copy__cpu(at::Tensor& self, const at::Tensor& src) {
AT_CHECK(self.numel() == src.numel(), "sizes do not match");
- AT_DISPATCH_ALL_TYPES_AND_HALF(src.type(), "_copy__cpu", [&]() {
+ AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.type(), "_copy__cpu", [&]() {
_copy__cpu<self_T, scalar_t>(self, src);
});
}
_s_copy_from(src, self, non_blocking);
return self;
}
- AT_DISPATCH_ALL_TYPES_AND_HALF(
- self.type(), "_copy__cpu", [&]() { ::_copy__cpu<scalar_t>(self, src); });
+ AT_DISPATCH_ALL_TYPES_AND(
+ at::ScalarType::Half, self.type(), "_copy__cpu", [&]() { ::_copy__cpu<scalar_t>(self, src); });
return self;
}
}
Tensor buf = empty({BLOCK_SZ, BLOCK_SZ}, self.options());
- AT_DISPATCH_ALL_TYPES_AND_HALF(
- self.type(), "_copy_same_type_transpose_", [&]() {
+ AT_DISPATCH_ALL_TYPES_AND(
+ at::ScalarType::Half, self.type(), "_copy_same_type_transpose_", [&]() {
scalar_t* sp = src.data<scalar_t>();
scalar_t* rp = self.data<scalar_t>();
scalar_t* bp = buf.data<scalar_t>();
} else {
#ifdef _OPENMP
if (!in_parallel_region()) {
- AT_DISPATCH_ALL_TYPES_AND_HALF(self.type(), "_copy_same_type_", [&]() {
- at::CPU_tensor_parallel_apply2<scalar_t, scalar_t>(
- self, src, [](scalar_t& self_val, const scalar_t& src_val) {
- self_val = src_val;
- });
- });
+ AT_DISPATCH_ALL_TYPES_AND(
+ at::ScalarType::Half, self.type(), "_copy_same_type_", [&]() {
+ at::CPU_tensor_parallel_apply2<scalar_t, scalar_t>(
+ self, src, [](scalar_t& self_val, const scalar_t& src_val) {
+ self_val = src_val;
+ });
+ });
} else {
serial_path = true;
}
}
if (serial_path) {
- AT_DISPATCH_ALL_TYPES_AND_HALF(self.type(), "_copy_same_type_", [&]() {
- at::CPU_tensor_apply2<scalar_t, scalar_t>(
- self, src, [](scalar_t& self_val, const scalar_t& src_val) {
- self_val = src_val;
- });
- });
+ AT_DISPATCH_ALL_TYPES_AND(
+ at::ScalarType::Half, self.type(), "_copy_same_type_", [&]() {
+ at::CPU_tensor_apply2<scalar_t, scalar_t>(
+ self, src, [](scalar_t& self_val, const scalar_t& src_val) {
+ self_val = src_val;
+ });
+ });
}
}
Scalar _local_scalar_dense_cpu(const Tensor& self) {
Scalar r;
- AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX(
- self.type(), "_local_scalar_dense_cpu", [&] {
+ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(
+ at::ScalarType::Half, self.type(), "_local_scalar_dense_cpu", [&] {
scalar_t value = *self.data<scalar_t>();
r = Scalar(value);
});
constexpr int64_t COPY_GRAIN_SIZE = 20000;
static void copy_kernel_impl(Tensor& dst, const Tensor& src) {
- AT_DISPATCH_ALL_TYPES_AND_HALF(dst.type(), "copy_kernel_impl", [&]() {
- scalar_t* self_ptr = dst.data<scalar_t>();
- scalar_t* src_ptr = src.data<scalar_t>();
-
- auto sample = [&](int64_t begin, int64_t end) {
- int64_t len = end - begin;
- scalar_t* self_seg = self_ptr + begin;
- scalar_t* src_seg = src_ptr + begin;
- at::vec256::convert<scalar_t, scalar_t>(src_seg, self_seg, len);
+ AT_DISPATCH_ALL_TYPES_AND(
+ at::ScalarType::Half, dst.type(), "copy_kernel_impl", [&]() {
+ scalar_t* self_ptr = dst.data<scalar_t>();
+ scalar_t* src_ptr = src.data<scalar_t>();
+
+ auto sample = [&](int64_t begin, int64_t end) {
+ int64_t len = end - begin;
+ scalar_t* self_seg = self_ptr + begin;
+ scalar_t* src_seg = src_ptr + begin;
+ at::vec256::convert<scalar_t, scalar_t>(src_seg, self_seg, len);
};
parallel_for(0, dst.numel(), COPY_GRAIN_SIZE, sample);
}
void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride) {
- AT_DISPATCH_ALL_TYPES_AND_HALF(iter.type(0), "index", [&] {
+ AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.type(0), "index", [&] {
cpu_index_kernel<scalar_t>(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
*(scalar_t*)dst = *(scalar_t*)(src + offset);
});
void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride, bool accumulate) {
// NOTE: duplicate indices are only supported if accumulate is true.
- AT_DISPATCH_ALL_TYPES_AND_HALF(iter.type(0), "index_put", [&] {
+ AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.type(0), "index_put", [&] {
if (accumulate) {
// TODO: investigate parallelization of the accumulate kernel. Unlike the non-accumulate case,
// this needs to be thread-safe.
}
static void threshold_kernel(TensorIterator& iter, Scalar threshold, Scalar value) {
- AT_DISPATCH_ALL_TYPES_AND_HALF(iter.type(), "threshold", [&] {
+ AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.type(), "threshold", [&] {
threshold_kernel_impl<scalar_t>(iter, threshold.to<scalar_t>(), value.to<scalar_t>());
});
}
self_row_stride = self.stride(-2), self_col_stride = self.stride(-1);
dim3 dim_block = cuda::getApplyBlock();
dim3 dim_grid((mat_size + dim_block.x - 1) / dim_block.x, n_batches);
- AT_DISPATCH_ALL_TYPES_AND_HALF(self.type(), name, [&]{
+ AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.type(), name, [&]{
triu_tril_kernel<scalar_t, upper>
<<<dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
result.data<scalar_t>(), self.data<scalar_t>(), k, mat_size,
}
static void add_kernel_cuda(TensorIterator& iter, Scalar alpha_scalar) {
- AT_DISPATCH_ALL_TYPES_AND_HALF(iter.type(), "add", [&]() {
+ AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.type(), "add", [&]() {
add_kernel_impl<scalar_t>(iter, alpha_scalar);
});
}
}
static void mul_kernel_cuda(TensorIterator& iter) {
- AT_DISPATCH_ALL_TYPES_AND_HALF(iter.type(), "mul", [&]() {
+ AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.type(), "mul", [&]() {
mul_kernel_impl<scalar_t>(iter);
});
}
Scalar _local_scalar_dense_cuda(const Tensor& self) {
Scalar r;
- AT_DISPATCH_ALL_TYPES_AND_HALF(
- self.type(), "_local_scalar_dense_cuda", [&] {
+ AT_DISPATCH_ALL_TYPES_AND(
+ at::ScalarType::Half, self.type(), "_local_scalar_dense_cuda", [&] {
scalar_t value;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_CUDA_CHECK(cudaMemcpyAsync(&value, self.data<scalar_t>(), sizeof(scalar_t), cudaMemcpyDeviceToHost, stream));
cudaMemcpyHostToDevice,
stream));
AT_CUDA_CHECK(cudaStreamSynchronize(stream));
- AT_DISPATCH_ALL_TYPES_AND_HALF(src.type(), "copy_from_cpu", [&]() {
+ AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.type(), "copy_from_cpu", [&]() {
copy_device_to_device<scalar_t, scalar_t>(dst, dst_contig);
});
}
CUDAGuard device_guard(dst.device());
CUDAStream stream = getCurrentCUDAStream();
- AT_DISPATCH_ALL_TYPES_AND_HALF(src.type(), "copy_from_cpu_async", [&]() {
+ AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.type(), "copy_from_cpu_async", [&]() {
AT_CUDA_CHECK(cudaMemcpyAsync(
dst.data<scalar_t>(),
src.data<scalar_t>(),
CUDAGuard device_guard(src.device());
CUDAStream stream = getCurrentCUDAStream();
- AT_DISPATCH_ALL_TYPES_AND_HALF(src.type(), "copy_to_cpu_async", [&]() {
+ AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.type(), "copy_to_cpu_async", [&]() {
AT_CUDA_CHECK(cudaMemcpyAsync(
dst.data<scalar_t>(),
src.data<scalar_t>(),
template <typename dst_T>
void _copy__cuda(Tensor& dst, const Tensor& src, bool non_blocking) {
AT_CHECK(dst.numel() == src.numel(), "sizes do not match");
- AT_DISPATCH_ALL_TYPES_AND_HALF(src.type(), "_copy__cuda", [&]() {
+ AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, src.type(), "_copy__cuda", [&]() {
if (dst.is_cuda() && src.is_cuda()) {
copy_device_to_device<dst_T, scalar_t>(dst, src);
} else if (dst.is_cuda()) {
namespace native {
Tensor& _s_copy__cuda(Tensor& self, const Tensor& src, bool non_blocking) {
- AT_DISPATCH_ALL_TYPES_AND_HALF(self.type(), "_copy__cuda", [&]() {
+ AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.type(), "_copy__cuda", [&]() {
::_copy__cuda<scalar_t>(self, src, non_blocking);
});
return self;
Tensor& bernoulli_tensor_cuda_(Tensor &self, const Tensor& p_, Generator* gen) {
auto p = std::get<0>(expand_inplace(self, p_.to(kCUDA)));
- AT_DISPATCH_ALL_TYPES_AND_HALF(self.type(), "bernoulli_tensor_cuda_self_", [&] {
- const at::Type& p_type = p.type();
- using self_t = scalar_t;
- auto seeds = next_philox_seed(gen, 10);
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(p.type(), "bernoulli_tensor_cuda_p_", [&] {
- using p_t = scalar_t;
- return bernoulli_tensor_cuda_kernel<self_t, p_t>(self, p, seeds);
- });
+ AT_DISPATCH_ALL_TYPES_AND(
+ at::ScalarType::Half, self.type(), "bernoulli_tensor_cuda_self_", [&] {
+ const at::Type& p_type = p.type();
+ using self_t = scalar_t;
+ auto seeds = next_philox_seed(gen, 10);
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(p.type(), "bernoulli_tensor_cuda_p_", [&] {
+ using p_t = scalar_t;
+ return bernoulli_tensor_cuda_kernel<self_t, p_t>(self, p, seeds);
+ });
});
return self;
}
Tensor& bernoulli_scalar_cuda_(Tensor &self, double p, Generator* gen) {
AT_CHECK(0 <= p && p <= 1, "bernoulli_ expects p to be in [0, 1], but got p=", p);
- AT_DISPATCH_ALL_TYPES_AND_HALF(self.type(), "bernoulli_scalar_cuda_", [&] {
+ AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.type(), "bernoulli_scalar_cuda_", [&] {
auto seeds = next_philox_seed(gen, 10);
bernoulli_scalar_cuda_kernel<scalar_t>(self, p, seeds);
});
}
static void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride) {
- AT_DISPATCH_ALL_TYPES_AND_HALF(iter.type(), "index", [&] {
+ AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.type(), "index", [&] {
using dtype = OpaqueType<sizeof(scalar_t)>;
index_kernel_impl<dtype>(iter, index_size, index_stride);
});
static void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride, bool accumulate) {
AT_ASSERTM(!accumulate, "index_put does not support accumulate=true");
- AT_DISPATCH_ALL_TYPES_AND_HALF(iter.type(), "index_put", [&] {
+ AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.type(), "index_put", [&] {
using dtype = OpaqueType<sizeof(scalar_t)>;
index_put_kernel_impl<dtype>(iter, index_size, index_stride);
});
}
Tensor& range_cuda_out(Tensor& result, Scalar start, Scalar end, Scalar step) {
- AT_DISPATCH_ALL_TYPES_AND_HALF(result.type(), "range", [&]() {
+ AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, result.type(), "range", [&]() {
using accscalar_t = at::acc_type<scalar_t, true>;
auto xstart = start.to<accscalar_t>();
auto xend = end.to<accscalar_t>();
}
Tensor& arange_cuda_out(Tensor& result, Scalar start, Scalar end, Scalar step) {
- AT_DISPATCH_ALL_TYPES_AND_HALF(result.type(), "arange", [&]() {
+ AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, result.type(), "arange", [&]() {
using accscalar_t = at::acc_type<scalar_t, true>;
auto xstart = start.to<accscalar_t>();
auto xend = end.to<accscalar_t>();
int64_t k,
int64_t dim,
bool keepdim) {
- AT_DISPATCH_ALL_TYPES_AND_HALF(self.type(), "kthvalue", [&] {
+ AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.type(), "kthvalue", [&] {
kthvalue_cuda_template<scalar_t>(values, indices, self, k, dim, keepdim);
});
return std::forward_as_tuple(values, indices);
}
Tensor median_cuda(const Tensor& self) {
- return AT_DISPATCH_ALL_TYPES_AND_HALF(self.type(), "median", [&] {
+ return AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.type(), "median", [&] {
return median_cuda_template<scalar_t>(self);
});
}
const Tensor& self,
const Tensor& other) {
Tensor ret = at::empty(self.sizes(), self.options());
- AT_DISPATCH_ALL_TYPES_AND_HALF(ret.type(), "where", [&] {
+ AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, ret.type(), "where", [&] {
where_cuda<scalar_t>(ret, condition, self, other);
});
return ret;
cuda::getApplyGrid(tril_size, dim_grid, tensor.get_device()),
"unable to get dim grid");
- AT_DISPATCH_ALL_TYPES_AND_HALF(tensor.type(), "tril_indices_cuda", [&] {
+ AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, tensor.type(), "tril_indices_cuda", [&] {
tril_indices_kernel<<<
dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
tensor.data<scalar_t>(),
cuda::getApplyGrid(triu_size, dim_grid, tensor.get_device()),
"unable to get dim grid");
- AT_DISPATCH_ALL_TYPES_AND_HALF(tensor.type(), "triu_indices_cuda", [&] {
+ AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, tensor.type(), "triu_indices_cuda", [&] {
triu_indices_kernel<<<
dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
tensor.data<scalar_t>(),
// use kernel_pointwise_flip_apply2 only when to-flip dim is the 1st or last dim, where collapseDims can reduce the amount of work
if (flip_dims_size == 1 && in_tensor.is_contiguous() && (flip_dims[0] == 0 || flip_dims[0] == total_dims - 1)) {
- AT_DISPATCH_ALL_TYPES_AND_HALF(in_tensor.type(), "flip_cuda", [&] {
+ AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, in_tensor.type(), "flip_cuda", [&] {
auto in_tensor_info = cuda::detail::getTensorInfo<scalar_t, int64_t>(in_tensor);
auto out_tensor_info = cuda::detail::getTensorInfo<scalar_t, int64_t>(out_tensor);
int flip_dim = in_tensor_info.collapseDims(flip_dims[0]);
}
}
- AT_DISPATCH_ALL_TYPES_AND_HALF(in_tensor.type(), "flip_cuda", [&] {
+ AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, in_tensor.type(), "flip_cuda", [&] {
flip_cuda_kernel<<<dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
in_tensor.data<scalar_t>(), out_tensor.data<scalar_t>(), N, flip_dims_t.toType(CUDA(kLong)).data<int64_t>(), flip_dims_size,
strides_t.toType(CUDA(kLong)).data<int64_t>(), stride_contiguous.toType(CUDA(kLong)).data<int64_t>(), shape_t.toType(CUDA(kLong)).data<int64_t>(), total_dims);
auto total_dims = in_tensor.dim();
- AT_DISPATCH_ALL_TYPES_AND_HALF(in_tensor.type(), "roll_cuda", [&] {
+ AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, in_tensor.type(), "roll_cuda", [&] {
roll_cuda_kernel<<<dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
in_tensor.data<scalar_t>(), out_tensor.data<scalar_t>(), N,
dim, start,
int64_t stride = at::prod_intlist(values.sizes().slice(1));
dim3 grid(THCCeilDiv(newNnz, (int64_t) 4), THCCeilDiv(stride, (int64_t) 128));
dim3 block(32, 4);
- AT_DISPATCH_ALL_TYPES_AND_HALF(
- values.type(), "coalesce_sparse_cuda", [&] {
+ AT_DISPATCH_ALL_TYPES_AND(
+        at::ScalarType::Half, values.type(), "coalesce_sparse_cuda", [&] {
using cuda_accscalar_t = acc_type<scalar_t, /* is_cuda */ true>;
apply::coalesceValuesKernel<scalar_t, cuda_accscalar_t><<<grid, block, 0, stream>>>(
uniqueOffsets.data<int64_t>(),
if (sparse.dense_dim() == 0) {
AT_CHECK(cuda::getApplyGrid(nnz, grid, curDevice), "add: Argument #0: tensor too large or too many dimensions");
- AT_DISPATCH_ALL_TYPES_AND_HALF(
- values.type(), "add_out_dense_sparse_cuda", [&] {
+ AT_DISPATCH_ALL_TYPES_AND(
+ at::ScalarType::Half, values.type(), "add_out_dense_sparse_cuda", [&] {
apply::sparseElementwiseKernelScalar<TensorCAddOp<scalar_t>, uint64_t, scalar_t>
<<<grid, block, 0, stream>>>(
TensorCAddOp<scalar_t>(value.to<scalar_t>()),
// sparseElementwiseKernel needs values to be contiguous too
values = values.contiguous();
- AT_DISPATCH_ALL_TYPES_AND_HALF(
- values.type(), "add_out_dense_sparse_cuda", [&] {
+ AT_DISPATCH_ALL_TYPES_AND(
+ at::ScalarType::Half, values.type(), "add_out_dense_sparse_cuda", [&] {
apply::sparseElementwiseKernel<TensorCAddOp<scalar_t>, uint64_t, scalar_t>
<<<grid, block, 0, stream>>>(
TensorCAddOp<scalar_t>(value.to<scalar_t>()),
// FIXME: at some point we can wrap the scale into indexAdd
// NB: Purposely not inplace!
- AT_DISPATCH_ALL_TYPES_AND_HALF(
- values.type(), "add_out_dense_sparse_cuda", [&] {
+ AT_DISPATCH_ALL_TYPES_AND(
+ at::ScalarType::Half, values.type(), "add_out_dense_sparse_cuda", [&] {
if (value.to<scalar_t>() != static_cast<scalar_t>(1)) {
values = values.mul(value);
}
LongTensor s_indices_ = src._indices();
Tensor s_values_ = src._values();
- AT_DISPATCH_ALL_TYPES_AND_HALF(
- s_values_.type(), "add_out_sparse_cuda", [&] {
+ AT_DISPATCH_ALL_TYPES_AND(
+ at::ScalarType::Half, s_values_.type(), "add_out_sparse_cuda", [&] {
if (value.to<scalar_t>() != static_cast<scalar_t>(1)) {
s_values_ = s_values_.mul(value);
}
AT_CHECK(cuda::getApplyGrid(valueSize, grid, curDevice), "mul: Argument #0: tensor too large or too many dimensions");
LongTensor resultNnz = at::empty({1}, CUDA(kLong));
- AT_DISPATCH_ALL_TYPES_AND_HALF(
- t_values_.type(), "mul_out_sparse_cuda", [&] {
+ AT_DISPATCH_ALL_TYPES_AND(
+ at::ScalarType::Half, t_values_.type(), "mul_out_sparse_cuda", [&] {
apply::valueSparseIntersectionKernel<TensorMulOp<scalar_t>, uint64_t, scalar_t>
<<<grid, block, 0, stream>>>(
TensorMulOp<scalar_t>(),