[topi] fix strategy for sparse dense cuda (#5782)
author Bing Xu <antinucleon@gmail.com>
Mon, 15 Jun 2020 01:54:15 +0000 (18:54 -0700)
committer GitHub <noreply@github.com>
Mon, 15 Jun 2020 01:54:15 +0000 (21:54 -0400)
python/tvm/relay/op/nn/_nn.py
python/tvm/relay/op/strategy/cuda.py
python/tvm/relay/op/strategy/generic.py
python/tvm/relay/op/strategy/x86.py
topi/python/topi/cuda/sparse.py
topi/python/topi/x86/sparse.py

index c09b873..1c76f57 100644 (file)
@@ -69,7 +69,7 @@ def compute_sparse_dense(attrs, inputs, out_type):
     """Compute definition of sparse_dense"""
     return [topi.nn.sparse_dense(inputs[0], inputs[1], inputs[2], inputs[3])]
 
-reg.register_schedule("nn.sparse_dense", strategy.schedule_sparse_dense)
+reg.register_strategy("nn.sparse_dense", strategy.sparse_dense_strategy)
 reg.register_pattern("nn.sparse_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
index 4b019cf..e0091a1 100644 (file)
@@ -493,6 +493,19 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
             plevel=15)
     return strategy
 
+
+@sparse_dense_strategy.register(["cuda", "gpu"])
+def sparse_dense_strategy_cuda(attrs, inputs, out_type, target):
+    """sparse_dense strategy for cuda/gpu targets, using the topi.cuda compute and schedule"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_sparse_dense(topi.cuda.sparse_dense),
+        wrap_topi_schedule(topi.cuda.schedule_sparse_dense),
+        name="sparse_dense.cuda",
+        plevel=10)
+    return strategy
+
+
 @argsort_strategy.register(["cuda", "gpu"])
 def argsort_strategy_cuda(attrs, inputs, out_type, target):
     """argsort cuda strategy"""
index 4fa2b11..b1fb421 100644 (file)
@@ -599,12 +599,22 @@ def batch_matmul_strategy(attrs, inputs, out_type, target):
                                 name="batch_matmul.generic")
     return strategy
 
-# sparse_dense
-@generic_func
-def schedule_sparse_dense(attrs, outs, target):
-    """schedule sparse_dense"""
-    with target:
-        return topi.generic.schedule_sparse_dense(outs)
+# sparse dense
+def wrap_compute_sparse_dense(topi_compute):
+    """Wrap a four-input topi compute as an (attrs, inputs, out_type) function returning a list"""
+    def _compute_sparse_dense(attrs, inputs, out_type):
+        return [topi_compute(inputs[0], inputs[1], inputs[2], inputs[3])]
+    return _compute_sparse_dense
+
+@override_native_generic_func("sparse_dense_strategy")
+def sparse_dense_strategy(attrs, inputs, out_type, target):
+    """Generic sparse_dense strategy: logs a warning, uses topi.nn compute and generic schedule"""
+    logger.warning("sparse dense is not optimized for this platform.")
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(wrap_compute_sparse_dense(topi.nn.sparse_dense),
+                                wrap_topi_schedule(topi.generic.schedule_sparse_dense),
+                                name="sparse_dense.generic")
+    return strategy
 
 # sparse_transpose
 @generic_func
index 0984e40..b02db41 100644 (file)
@@ -294,11 +294,16 @@ def batch_matmul_strategy_cpu(attrs, inputs, out_type, target):
                                     plevel=15)
     return strategy
 
-@schedule_sparse_dense.register("cpu")
-def schedule_sparse_dense_cpu(attrs, outs, target):
-    """schedule sparse_dense for x86"""
-    with target:
-        return topi.x86.schedule_sparse_dense(outs)
+@sparse_dense_strategy.register("cpu")
+def sparse_dense_strategy_cpu(attrs, inputs, out_type, target):
+    """sparse_dense strategy for cpu targets: topi.nn compute with the topi.x86 schedule"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(wrap_compute_sparse_dense(topi.nn.sparse_dense),
+                                wrap_topi_schedule(topi.x86.schedule_sparse_dense),
+                                name="sparse_dense.x86",
+                                plevel=10)
+    return strategy
+
 
 @roi_align_strategy.register("cpu")
 def roi_align_strategy_cpu(attrs, inputs, out_type, target):
index fb875b7..5b57000 100644 (file)
@@ -63,7 +63,6 @@ def schedule_sparse_dense(cfg, outs):
     """Create schedule for sparse dense"""
     # pylint:disable=invalid-name
     s = te.create_schedule([x.op for x in outs])
-
     def _callback(op):
         if op.tag == "sparse_dense_bsrmm":
             y_bsrmm = op.input_tensors[0]
index 54a5af9..02cbd2d 100644 (file)
@@ -21,11 +21,9 @@ from tvm import te
 from ..util import traverse_inline, get_const_int
 from .util import get_fp32_len
 
-
 def schedule_sparse_dense(outs):
     """Create schedule for sparse dense"""
     s = te.create_schedule([x.op for x in outs])
-
     def _callback(op):
         simd_width = get_fp32_len()
         if op.tag == "sparse_dense_csrmm" and op != outs[0].op: