[Relay] [TOPI] `{relay,topi}.nn.sparse_transpose` for **Square** CSR matrices (#3707)
authorYulun Yao <allen980123@gmail.com>
Tue, 6 Aug 2019 01:13:22 +0000 (18:13 -0700)
committerThierry Moreau <moreau@uw.edu>
Tue, 6 Aug 2019 01:13:22 +0000 (18:13 -0700)
* add build gcn tutorial

* add transpose operator for square sparse matrices

* remove extra files

* change loop tag

* comply with lint

* comply with lint -- line too long

* comply with lint

* lint check

* lint check

* lint check

* apply marisa and theirry's reviews

include/tvm/relay/attrs/nn.h
python/tvm/relay/op/nn/_nn.py
python/tvm/relay/op/nn/nn.py
src/relay/op/nn/sparse.cc
topi/python/topi/generic/nn.py
topi/python/topi/nn/sparse.py
topi/tests/python/test_topi_sparse.py

index ca4e8d3..58c4bba 100644 (file)
@@ -371,6 +371,11 @@ struct SparseDenseAttrs : public tvm::AttrsNode<SparseDenseAttrs> {
   TVM_DECLARE_ATTRS(SparseDenseAttrs, "relay.attrs.SparseDenseAttrs") {}
 };
 
+/*! \brief Attributes for sparse_transpose operator */
+struct SparseTransposeAttrs : public tvm::AttrsNode<SparseTransposeAttrs> {
+  TVM_DECLARE_ATTRS(SparseTransposeAttrs, "relay.attrs.SparseTransposeAttrs") {}
+};
+
 /*! \brief Attributes for upsampling operator */
 struct UpSamplingAttrs : public tvm::AttrsNode<UpSamplingAttrs> {
   int scale;
index b50a27b..0c374b8 100644 (file)
@@ -99,6 +99,20 @@ def schedule_sparse_dense(attrs, outputs, target):
 
 reg.register_pattern("nn.sparse_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
 
+# sparse_transpose
+@reg.register_compute("nn.sparse_transpose")
+def compute_sparse_transpose(attrs, inputs, out_type, target):
+    """Compute definition of sparse_transpose"""
+    return topi.nn.sparse_transpose(inputs[0], inputs[1], inputs[2])
+
+@reg.register_schedule("nn.sparse_transpose")
+def schedule_sparse_transpose(attrs, outputs, target):
+    """Schedule definition of batch_matmul"""
+    with target:
+        return topi.generic.schedule_sparse_transpose(outputs)
+
+reg.register_pattern("nn.sparse_transpose", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
+
 # conv2d
 def _find_conv2d_op(op):
     """Find the op with conv2d in its tag by traversing."""
index 46c01be..4a83ef2 100644 (file)
@@ -954,6 +954,33 @@ def sparse_dense(data, weight):
     """
     return _make.sparse_dense(data, weight.data, weight.indices, weight.indptr)
 
+def sparse_transpose(x):
+    r"""
+    Computes the fast matrix transpose of x,
+    where x is a sparse tensor in CSR format (represented as a namedtuple
+    with fields `data`, `indices`, and `indptr`).
+
+    ** Currently only support Square Matrices **
+
+    .. math::
+
+        \mbox{sparse_transpose}(x)[n, n] = (x^T)[n, n]
+
+    Please refer to https://github.com/scipy/scipy/blob/v1.3.0/scipy/sparse/csr.py
+    for the algorithm implemented in this operator.
+
+    Parameters
+    ----------
+    x : namedtuple.
+        The sparse weight matrix for the fast matrix transpose.
+
+    Returns
+    -------
+    result : relay.Tuple([tvm.relay.Expr, tvm.relay.Expr, tvm.relay.Expr])
+        Tuple of output sparse tensor (same shape and format as input),
+        i.e. if CSR then output is in ([data, indices, indptr]) form
+    """
+    return TupleWrapper(_make.sparse_transpose(x.data, x.indices, x.indptr), 3)
 
 def contrib_conv2d_winograd_without_weight_transform(data,
                                                      weight,
index 3e81787..48a9b11 100644 (file)
@@ -72,26 +72,72 @@ Expr MakeSparseDense(Expr data, Expr weight_data, Expr weight_indices, Expr weig
 }
 
 TVM_REGISTER_API("relay.op.nn._make.sparse_dense")
-    .set_body([](const TVMArgs& args, TVMRetValue* rv) {
-      runtime::detail::unpack_call<Expr, 4>(MakeSparseDense, args, rv);
-    });
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+  runtime::detail::unpack_call<Expr, 4>(MakeSparseDense, args, rv);
+});
 
 RELAY_REGISTER_OP("nn.sparse_dense")
-    .describe(R"code(Applies a sparse linear transformation: :math:`Y = XW^T` with X sparse.
+.describe(R"code(Applies a sparse linear transformation: :math:`Y = XW^T` with X sparse.
 
 - **data**: `(x1, x2, ..., xn, input_dim)`
 - **weight**: `(units, input_dim)`
 - **out**: `(x1, x2, ..., xn, units)`.
 
 )code" TVM_ADD_FILELINE)
-    .set_attrs_type_key("relay.attrs.SparseDenseAttrs")
-    .set_num_inputs(4)
-    .add_argument("data", "nD Tensor", "Input data.")
-    .add_argument("weight_data", "1D Tensor", "Weight data matrix.")
-    .add_argument("weight_indices", "1D Tensor", "Weight indices matrix.")
-    .add_argument("weight_indptr", "1D Tensor", "Weight indptr matrix.")
-    .set_support_level(1)
-    .add_type_rel("SparseDense", SparseDenseRel);
+.set_attrs_type_key("relay.attrs.SparseDenseAttrs")
+.set_num_inputs(4)
+.add_argument("data", "nD Tensor", "Input data.")
+.add_argument("weight_data", "1D Tensor", "Weight data matrix.")
+.add_argument("weight_indices", "1D Tensor", "Weight indices matrix.")
+.add_argument("weight_indptr", "1D Tensor", "Weight indptr matrix.")
+.set_support_level(1)
+.add_type_rel("SparseDense", SparseDenseRel);
+
+// relay.nn.sparse_transpose
+TVM_REGISTER_NODE_TYPE(SparseTransposeAttrs);
+
+bool SparseTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                    const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 4);
+  const auto* sparse_data = types[0].as<TensorTypeNode>();
+  CHECK_EQ(sparse_data->shape.size(), 1);
+  const auto* sparse_indices = types[1].as<TensorTypeNode>();
+  CHECK_EQ(sparse_indices->shape.size(), 1);
+  const auto* sparse_indptr = types[2].as<TensorTypeNode>();
+
+  std::vector<Type> output_types;
+  output_types.push_back(TensorTypeNode::make(sparse_data->shape, sparse_data->dtype));
+  output_types.push_back(TensorTypeNode::make(sparse_indices->shape, sparse_indices->dtype));
+  output_types.push_back(TensorTypeNode::make(sparse_indptr->shape, sparse_indptr->dtype));
+
+  reporter->Assign(types[3], TupleTypeNode::make(Array<Type>(output_types)));
+  return true;
+}
+
+Expr MakeSparseTranspose(Expr sparse_data, Expr sparse_indices, Expr sparse_indptr) {
+  auto attrs = make_node<SparseTransposeAttrs>();
+  static const Op& op = Op::Get("nn.sparse_transpose");
+  return CallNode::make(op, {sparse_data, sparse_indices, sparse_indptr}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_API("relay.op.nn._make.sparse_transpose")
+.set_body_typed(MakeSparseTranspose);
+
+
+RELAY_REGISTER_OP("nn.sparse_transpose")
+.describe(R"code(Transpose a sparse matrix X. Only support square sparse matrix
+
+- **input**: `(N, N)`
+- **out**: `(N, N)`.
+
+)code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.SparseTransposeAttrs")
+.set_num_inputs(3)
+.add_argument("sparse_data", "1D Tensor", "Sparse data matrix.")
+.add_argument("sparse_indices", "1D Tensor", "Sparse indices matrix.")
+.add_argument("sparse_indptr", "1D Tensor", "Sparse index pointer matrix.")
+.set_support_level(1)
+.add_type_rel("SparseTranspose", SparseTransposeRel);
 
 }  // namespace relay
 }  // namespace tvm
index 59ee700..38b6632 100644 (file)
@@ -544,6 +544,23 @@ def schedule_sparse_dense(outs):
     return _default_schedule(outs, False)
 
 @tvm.target.generic_func
+def schedule_sparse_transpose(outs):
+    """Schedule for sparse_transpose
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+          The computation graph description of sparse_transpose
+          in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
+
+@tvm.target.generic_func
 def schedule_batch_matmul(outs):
     target = tvm.target.current_target(allow_none=False)
     cpp_target = cpp.TEST_create_target(target.target_name)
index 17b30ad..11116b2 100644 (file)
@@ -101,3 +101,106 @@ def _sparse_dense_bsrmm(data, weight_data, weight_indices, weight_indptr):
         (m, num_blocks * bs_r),
         lambda m, n: bsrmm_block[m, n // bs_r, n % bs_r],
         tag="sparse_dense_bsrmm")
+
+@tvm.target.generic_func
+def sparse_transpose(sparse_data, sparse_indices, sparse_indptr):
+    """
+    Transpose a square sparse matrix,
+    `A` is an n-by-n sparse matrix in the CSR format.
+    ** Currently only support Square Matrices **
+
+    Parameters
+    ----------
+    sparse_data : tvm.Tensor
+        1-D with shape [nonzeros], dtype of 'float32'
+
+    sparse_indices : tvm.Tensor
+        1-D with shape [nonzeros], dtype of 'int32'
+
+    sparse_indptr : tvm.Tensor
+        1-D with shape [n+1], dtype of 'int32'
+
+    Returns
+    -------
+    out_data : tvm.Tensor
+        1-D with shape [nonzeros], dtype of 'float32'
+
+    out_indices : tvm.Tensor
+        1-D with shape [nonzeros], dtype of 'int32'
+
+    out_indptr : tvm.Tensor
+        1-D with shape [n+1], dtype of 'int32'
+    """
+    assert len(sparse_data.shape) == 1, "error in data dimension"
+    assert len(sparse_indices.shape) == 1, "error in indices dimension"
+    assert len(sparse_indptr.shape) == 1, "error in indptr dimension"
+
+    nnz = get_const_tuple(sparse_data.shape)[0]
+    n = get_const_tuple(sparse_indptr.shape)[0] - 1
+    output_shape = [(nnz,), (nnz,), (n+1,)]
+
+    # TODO: Add BSR transpose support
+
+    output_data, output_indices, output_indptr = tvm.extern(
+        shape=output_shape,
+        inputs=[sparse_data, sparse_indices, sparse_indptr],
+        fcompute=lambda ins, outs:
+        csr_transpose_ir(ins[0], ins[1], ins[2], outs[0], outs[1], outs[2]),
+        tag="sparse_transpose_csr",
+        dtype=['float32', 'int32', 'int32'],
+        name='out')
+
+    return [output_data, output_indices, output_indptr]
+
+def csr_transpose_ir(data, indices, indptr, out_data, out_indices, out_indptr):
+    """define ir for csr_transpose"""
+    irb = tvm.ir_builder.create()
+
+    data_ptr = irb.buffer_ptr(data)
+    indices_ptr = irb.buffer_ptr(indices)
+    indptr_ptr = irb.buffer_ptr(indptr)
+
+    out_data_ptr = irb.buffer_ptr(out_data)
+    out_indices_ptr = irb.buffer_ptr(out_indices)
+    out_indptr_ptr = irb.buffer_ptr(out_indptr)
+
+    n = get_const_tuple(indptr.shape)[0] - 1
+    nnz = get_const_tuple(data.shape)[0]
+
+    with irb.for_range(0, n, for_type="parallel", name='col') as col:
+        out_indptr_ptr[col] = 0
+
+    with irb.for_range(0, nnz, for_type="serial", name='nz_idx') as nz_idx:
+        out_indptr_ptr[indices_ptr[nz_idx]] += 1
+
+    cumsum = irb.allocate('int32', (1,), name='cumsum', scope='local')
+    temp = irb.allocate('int32', (1,), name='temp', scope='local')
+    cumsum[0] = 0
+    with irb.for_range(0, n, for_type="serial", name='col') as col:
+        temp[0] = out_indptr_ptr[col]
+        out_indptr_ptr[col] = cumsum[0]
+        cumsum[0] += temp[0]
+
+    out_indptr_ptr[n] = nnz
+
+    with irb.for_range(0, n, for_type="serial", name='row') as row:
+        offset = indptr_ptr[row]
+        diff = indptr_ptr[row+1] - indptr_ptr[row]
+        with irb.for_range(0, diff, for_type="serial", name='idx') as idx:
+            real_idx = offset + idx
+            col = indices_ptr[real_idx]
+            dest = out_indptr_ptr[col]
+
+            out_indices_ptr[dest] = row
+            out_data_ptr[dest] = data_ptr[real_idx]
+            out_indptr_ptr[col] += 1
+
+    last = irb.allocate('int32', (1,), name='last', scope='local')
+    temp2 = irb.allocate('int32', (1,), name='temp2', scope='local')
+    last[0] = 0
+    with irb.for_range(0, n, for_type="serial", name="col") as col:
+        temp2[0] = out_indptr_ptr[col]
+        out_indptr_ptr[col] = last[0]
+        last[0] = temp2[0]
+
+    return irb.get()
index 49324b7..1b40b13 100644 (file)
@@ -23,6 +23,7 @@ from topi.util import get_const_tuple
 import tvm.contrib.sparse as tvmsp
 from collections import namedtuple
 import time
+import scipy.sparse as sp
 
 def verify_dynamic_csrmv(batch, in_dim, out_dim, use_bias=True):
     nr, nc, n = tvm.var("nr"), tvm.var("nc"), tvm.var("n")
@@ -217,7 +218,6 @@ def test_dense():
 
 
 def test_sparse_dense_csr():
-    import scipy.sparse as sp
     M, N, K, density = 1, 17, 47, 0.2
     X_np = np.random.randn(M, K).astype("float32")
     W_sp_np = sp.random(N, K, density=density, format='csr', dtype="float32")
@@ -235,9 +235,34 @@ def test_sparse_dense_csr():
     func(tvm.ndarray.array(X_np), tvm.ndarray.array(W_sp_np.data), tvm.ndarray.array(W_sp_np.indices), tvm.ndarray.array(W_sp_np.indptr), Y_tvm)
     tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-4, rtol=1e-4)
 
+def test_sparse_transpose_csr():
+    N, density = 1023, 0.3
+
+    X_sp = sp.random(N, N, density=density, format='csr', dtype='float32')
+
+    X_sp_T = X_sp.transpose()
+    X_np_T = X_sp_T.todense()
+
+    X_data = tvm.placeholder(shape=X_sp.data.shape, dtype=str(X_sp.data.dtype))
+    X_indices = tvm.placeholder(shape=X_sp.indices.shape, dtype=str(X_sp.indices.dtype))
+    X_indptr = tvm.placeholder(shape=X_sp.indptr.shape, dtype=str(X_sp.indptr.dtype))
+    
+    X_T_data, X_T_indices, X_T_indptr = topi.nn.sparse_transpose(X_data, X_indices, X_indptr)
+    s = tvm.create_schedule([X_T_data.op, X_T_indices.op, X_T_indptr.op])
+    func = tvm.build(s, [X_data, X_indices, X_indptr, X_T_data, X_T_indices, X_T_indptr])
+
+
+    X_T_data_tvm = tvm.ndarray.array(np.zeros(X_sp_T.data.shape, dtype=X_sp_T.data.dtype))
+    X_T_indices_tvm = tvm.ndarray.array(np.zeros(X_sp_T.indices.shape, dtype=X_sp_T.indices.dtype))
+    X_T_indptr_tvm = tvm.ndarray.array(np.zeros(X_sp_T.indptr.shape, dtype=X_sp_T.indptr.dtype))
+
+    func(tvm.ndarray.array(X_sp.data), tvm.ndarray.array(X_sp.indices), tvm.ndarray.array(X_sp.indptr),
+        X_T_data_tvm,  X_T_indices_tvm, X_T_indptr_tvm)
+
+    X_T_out = sp.csr_matrix((X_T_data_tvm.asnumpy(), X_T_indices_tvm.asnumpy(), X_T_indptr_tvm.asnumpy()), shape=(N,N)).todense()
+    tvm.testing.assert_allclose(X_np_T, X_T_out, atol=1e-4, rtol=1e-4)
 
 def random_bsr_matrix(M, N, BS_R, BS_C, density, dtype):
-    import scipy.sparse as sp
     import itertools
     Y = np.zeros((M, N), dtype=dtype)
     assert M % BS_R == 0
@@ -318,3 +343,4 @@ if __name__ == "__main__":
     test_csrmm()
     test_dense()
     test_sparse_dense()
+    test_sparse_transpose_csr()