"""Compute definition of sparse_dense"""
return [topi.nn.sparse_dense(inputs[0], inputs[1], inputs[2], inputs[3])]
-reg.register_schedule("nn.sparse_dense", strategy.schedule_sparse_dense)
+reg.register_strategy("nn.sparse_dense", strategy.sparse_dense_strategy)
reg.register_pattern("nn.sparse_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
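For orientation: the hunk above appears to come from relay's op registrations (_nn.py). The separate compute definition and register_schedule call collapse into a single register_strategy entry point, and the old compute body reappears below inside wrap_compute_sparse_dense. The function handed to register_strategy follows the FTVMStrategy convention; a minimal sketch, using only constructs added elsewhere in this diff and assuming that module's imports:

    def sparse_dense_strategy(attrs, inputs, out_type, target):
        strategy = _op.OpStrategy()
        strategy.add_implementation(fcompute, fschedule,  # a (compute, schedule) pair
                                    name="sparse_dense.generic", plevel=10)
        return strategy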
            plevel=15)
    return strategy
+
+@sparse_dense_strategy.register(["cuda", "gpu"])
+def sparse_dense_strategy_cuda(attrs, inputs, out_type, target):
+ """sparse dense cuda strategy"""
+ strategy = _op.OpStrategy()
+ strategy.add_implementation(
+ wrap_compute_sparse_dense(topi.cuda.sparse_dense),
+ wrap_topi_schedule(topi.cuda.schedule_sparse_dense),
+ name="sparse_dense.cuda",
+ plevel=10)
+ return strategy
+
+
@argsort_strategy.register(["cuda", "gpu"])
def argsort_strategy_cuda(attrs, inputs, out_type, target):
"""argsort cuda strategy"""
name="batch_matmul.generic")
return strategy
-# sparse_dense
-@generic_func
-def schedule_sparse_dense(attrs, outs, target):
- """schedule sparse_dense"""
- with target:
- return topi.generic.schedule_sparse_dense(outs)
+# sparse dense
+def wrap_compute_sparse_dense(topi_compute):
+ """wrap sparse dense topi compute"""
+ def _compute_sparse_dense(attrs, inputs, out_type):
+ return [topi_compute(inputs[0], inputs[1], inputs[2], inputs[3])]
+ return _compute_sparse_dense
+
+@override_native_generic_func("sparse_dense_strategy")
+def sparse_dense_strategy(attrs, inputs, out_type, target):
+ """sparse dense generic strategy"""
+ logger.warning("sparse dense is not optimized for this platform.")
+ strategy = _op.OpStrategy()
+ strategy.add_implementation(wrap_compute_sparse_dense(topi.nn.sparse_dense),
+ wrap_topi_schedule(topi.generic.schedule_sparse_dense),
+ name="sparse_dense.generic")
+ return strategy
# sparse_transpose
@generic_func
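For reference, wrap_topi_schedule (defined elsewhere in this strategy module) factors out exactly what the removed @generic_func body did, entering the target scope before calling the TOPI schedule; a sketch consistent with the removed lines above:

    def wrap_topi_schedule(topi_schedule):
        """Wrap a TOPI schedule that does not need attrs."""
        def _schedule(attrs, outs, target):
            with target:
                return topi_schedule(outs)
        return _schedule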
            plevel=15)
    return strategy
-@schedule_sparse_dense.register("cpu")
-def schedule_sparse_dense_cpu(attrs, outs, target):
- """schedule sparse_dense for x86"""
- with target:
- return topi.x86.schedule_sparse_dense(outs)
+@sparse_dense_strategy.register("cpu")
+def sparse_dense_strategy_cpu(attrs, inputs, out_type, target):
+ """sparse dense x86 strategy"""
+ strategy = _op.OpStrategy()
+ strategy.add_implementation(wrap_compute_sparse_dense(topi.nn.sparse_dense),
+ wrap_topi_schedule(topi.x86.schedule_sparse_dense),
+ name="sparse_dense.x86",
+ plevel=10)
+ return strategy
+
@roi_align_strategy.register("cpu")
def roi_align_strategy_cpu(attrs, inputs, out_type, target):
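Note that the x86 and CUDA overrides both register at plevel=10; selection between them happens by target, not priority. override_native_generic_func keeps a per-target table, so under an llvm target sparse_dense_strategy resolves to the "cpu" registration above, while an unregistered target falls through to the generic body and its warning. Schematically (attrs, inputs, and out_type stand in for the values relay supplies; tvm.target.create reflects the API of this era):

    import tvm

    with tvm.target.create("llvm"):
        # resolves to sparse_dense_strategy_cpu inside this scope
        strat = sparse_dense_strategy(attrs, inputs, out_type, target)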
"""Create schedule for sparse dense"""
# pylint:disable=invalid-name
s = te.create_schedule([x.op for x in outs])
-
def _callback(op):
if op.tag == "sparse_dense_bsrmm":
y_bsrmm = op.input_tensors[0]
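The hunk above appears to be the CUDA TOPI schedule, keyed on the "sparse_dense_bsrmm" tag; below is its x86 counterpart. Both fragments end before the step that applies the callback, which in TOPI follows the traverse_inline pattern; a sketch with the scheduling body elided, assuming the imports shown in these files:

    def schedule_sparse_dense(outs):
        s = te.create_schedule([x.op for x in outs])

        def _callback(op):
            if op.tag == "sparse_dense_bsrmm":
                pass  # split/bind the block-sparse matmul stages here

        # Walk the op graph from the output, inlining injective ops and
        # invoking _callback on each op encountered.
        traverse_inline(s, outs[0].op, _callback)
        return s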
from ..util import traverse_inline, get_const_int
from .util import get_fp32_len
-
def schedule_sparse_dense(outs):
"""Create schedule for sparse dense"""
s = te.create_schedule([x.op for x in outs])
-
def _callback(op):
simd_width = get_fp32_len()
if op.tag == "sparse_dense_csrmm" and op != outs[0].op:
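get_fp32_len() reports the fp32 SIMD lane count of the current x86 target (8 for AVX2, 16 for AVX-512). The hunk ends before the body of the csrmm branch; the width's typical use there is as a split factor for vectorizing the output's inner axis, roughly (an illustrative sketch, not the file's exact lines):

    # inside _callback, after the tag check:
    y_o, y_i = s[outs[0].op].split(s[outs[0].op].op.axis[1], factor=simd_width)
    s[outs[0].op].vectorize(y_i)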