// NB: OMP pragmas have to get their own functions; can't put them in lambdas
template <typename scalar_t>
-void s_addmm_out_sparse_dense_worker(int64_t nnz, int64_t dim_i, int64_t dim_j, int64_t dim_k, Tensor& r, Scalar beta, const Tensor& t, Scalar alpha, const Tensor& csr, const Tensor& indices, const Tensor& values, const Tensor& dense) {
- int64_t h, i;
+void s_addmm_out_sparse_dense_worker(int64_t nnz, int64_t dim_i, int64_t dim_j, int64_t dim_k, Tensor& r, Scalar beta, const Tensor& t, Scalar alpha, const Tensor& indices, const Tensor& values, const Tensor& dense) {
+ int64_t i;
// r_ = alpha * sparse * dense
scalar_t cast_alpha = alpha.to<scalar_t>();
at::mul_out(r, t, scalar_to_tensor(beta));
}
- auto csr_accessor = csr.accessor<int64_t, 1>();
auto indices_accessor = indices.accessor<int64_t, 2>();
auto values_accessor = values.accessor<scalar_t, 1>();
int64_t dense_stride1 = dense.stride(1);
int64_t r_stride0 = r.stride(0);
int64_t r_stride1 = r.stride(1);
-#pragma omp parallel for private(h, i) schedule(static) if (nnz > 10000)
- for (h = 0; h < dim_i; h++) {
- int64_t i_start = csr_accessor[h];
- int64_t i_end = csr_accessor[h+1];
- for (i = i_start; i < i_end; i++) {
- scalar_t val = values_accessor[i];
- int64_t col = indices_accessor[1][i];
- if (col >= 0 && col < dim_j) {
- THBlas_axpy<scalar_t>(dim_k,
+ for (i = 0; i < nnz; i++) {
+ scalar_t val = values_accessor[i];
+ int64_t row = indices_accessor[0][i];
+ int64_t col = indices_accessor[1][i];
+ if (col >= 0 && col < dim_j && row >= 0 && row < dim_i) {
+ THBlas_axpy<scalar_t>(dim_k,
cast_alpha * val,
dense_ptr + col * dense_stride0, dense_stride1,
- r_ptr + h * r_stride0, r_stride1);
+ r_ptr + row * r_stride0, r_stride1);
+ } else {
+ if (col < 0 || col >= dim_j) {
+ AT_ERROR("addmm: index out of column bound: ", col, " not between 1 and ", dim_j);
} else {
- AT_ERROR("addmm: index out of bound: ", col, " not between 1 and ", dim_j);
+ AT_ERROR("addmm: index out of row bound: ", row, " not between 1 and ", dim_i);
}
}
}
AT_CHECK(sparse_.dense_dim() == 0, "addmm: scalar values expected, got ", sparse_.dense_dim(), "D values");
AT_CHECK(dense.dim() == 2, "addmm: matrices expected, got ", dense.dim(), "D tensor");
- SparseTensor sparse = sparse_.coalesce();
-
// ixj * jxk = ixk
- int64_t dim_i = sparse.size(0);
- int64_t dim_j = sparse.size(1);
+ int64_t dim_i = sparse_.size(0);
+ int64_t dim_j = sparse_.size(1);
int64_t dim_k = dense.size(1);
AT_CHECK(dense.size(0) == dim_j,
r.resize_({dim_i, dim_k});
- int64_t nnz = sparse._nnz();
+ int64_t nnz = sparse_._nnz();
if (nnz == 0) {
at::mul_out(r, t, at::scalar_tensor(beta, r.options()));
return r;
}
- LongTensor indices = sparse._indices();
- Tensor values = sparse._values();
- LongTensor csr = _to_csr(indices.data<int64_t>(), dim_i, nnz);
+ LongTensor indices = sparse_._indices();
+ Tensor values = sparse_._values();
AT_DISPATCH_ALL_TYPES(
values.type(), "addmm_sparse_dense", [&] {
- s_addmm_out_sparse_dense_worker<scalar_t>(nnz, dim_i, dim_j, dim_k, r, beta, t, alpha, csr, indices, values, dense);
+ s_addmm_out_sparse_dense_worker<scalar_t>(nnz, dim_i, dim_j, dim_k, r, beta, t, alpha, indices, values, dense);
}
);