[Bugfix] Fix sort changing original input data issue (#3212)
authorLeyuan Wang <laurawly@gmail.com>
Wed, 22 May 2019 18:09:01 +0000 (11:09 -0700)
committerHaichen Shen <shenhaichen@gmail.com>
Wed, 22 May 2019 18:09:01 +0000 (11:09 -0700)
* sort bugfix for not rearranging input data

* separate sort schedule

* fix lint

* use identity op instead

* fix lint

* remove redundent code

src/op/extern_op.cc
topi/python/topi/cuda/nms.py
topi/python/topi/cuda/sort.py
topi/python/topi/cuda/vision.py

index e6c6039..7023aeb 100644 (file)
@@ -72,7 +72,10 @@ Operation ExternOpNode::make(std::string name,
   CHECK_EQ(inputs.size(), input_placeholders.size());
   for (size_t i = 0; i < inputs.size(); ++i) {
     CHECK_EQ(inputs[i]->dtype, input_placeholders[i]->dtype);
-    CHECK(inputs[i]->shape.same_as(input_placeholders[i]->shape));
+    CHECK_EQ(inputs[i]->shape.size(), input_placeholders[i]->shape.size());
+    for (size_t dim = 0; dim < inputs[i]->shape.size(); ++dim) {
+        CHECK(inputs[i]->shape[dim].same_as(input_placeholders[i]->shape[dim]));
+    }
     CHECK_EQ(input_placeholders[i]->strides.size(), 0U);
   }
   n->inputs = std::move(inputs);
index 0c27bd2..925cf24 100644 (file)
@@ -24,6 +24,7 @@ from tvm.generic import cast
 from tvm.intrin import if_then_else, log, power
 from topi.vision import non_max_suppression, get_valid_counts
 from .sort import argsort
+from .. import tag
 
 
 def get_valid_counts_pre(data, flag, idx, score_threshold):
@@ -730,7 +731,7 @@ def non_max_suppression_gpu(data, valid_count, max_output_size=-1,
                                       "valid_count_buf", data_alignment=4)
     score_axis = score_index
     score_shape = (batch_size, num_anchors)
-    score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis])
+    score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE)
     sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False, flag=True)
 
     sort_tensor_buf = api.decl_buffer(sort_tensor.shape, sort_tensor.dtype,
index 99ba852..678d494 100644 (file)
@@ -20,6 +20,10 @@ import tvm
 
 from tvm import api
 from topi.sort import argsort
+from topi.math import identity
+from .. import generic
+from .. import tag
+
 
 def sort_ir(data, output, axis, is_ascend):
     """Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU.
@@ -104,8 +108,6 @@ def sort_ir(data, output, axis, is_ascend):
 
     return ib.get()
 
-
-
 def sort_nms_ir(data, valid_count, output, axis, is_ascend):
     """Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU.
 
@@ -221,29 +223,60 @@ def argsort_gpu(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0
     out : tvm.Tensor
         The output of this function.
     """
-    data_buf = api.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
+    sorted_data_buf = api.decl_buffer(data.shape, data.dtype, "sorted_data_buf", data_alignment=8)
+    sorted_data = identity(data)
     if flag:
         valid_count_buf = api.decl_buffer(valid_count.shape, valid_count.dtype,
                                           "valid_count_buf", data_alignment=4)
         out_buf = api.decl_buffer(data.shape, "int32", "out_buf", data_alignment=4)
         out = tvm.extern([data.shape],
-                         [data, valid_count],
+                         [sorted_data, valid_count],
                          lambda ins, outs: sort_nms_ir(
                              ins[0], ins[1], outs[0], axis, is_ascend),
                          dtype="int32",
-                         in_buffers=[data_buf, valid_count_buf],
+                         in_buffers=[sorted_data_buf, valid_count_buf],
                          out_buffers=[out_buf],
                          name="argsort_nms_gpu",
                          tag="argsort_nms_gpu")
     else:
         out_buf = api.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8)
         out = tvm.extern([data.shape],
-                         [data],
+                         [sorted_data],
                          lambda ins, outs: sort_ir(
                              ins[0], outs[0], axis, is_ascend),
                          dtype=dtype,
-                         in_buffers=[data_buf],
+                         in_buffers=[sorted_data_buf],
                          out_buffers=[out_buf],
                          name="argsort_gpu",
                          tag="argsort_gpu")
     return out
+
+@generic.schedule_argsort.register(["cuda", "gpu"])
+def schedule_argsort(outs):
+    """Schedule for argsort operator.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of argsort
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+      The computation schedule for the op.
+    """
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    s = tvm.create_schedule([x.op for x in outs])
+    scheduled_ops = []
+    from .injective import _schedule_injective
+    def traverse(op):
+        if tag.is_broadcast(op.tag):
+            _schedule_injective(op, s)
+        for tensor in op.input_tensors:
+            if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+                traverse(tensor.op)
+        scheduled_ops.append(op)
+    traverse(outs[0].op)
+
+    return s
index 78f5c1f..968e554 100644 (file)
@@ -25,41 +25,17 @@ from .pooling import schedule_pool
 
 def _default_schedule(outs):
     """Default schedule for gpu."""
-    target = tvm.target.current_target()
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
     scheduled_ops = []
-
+    from .injective import _schedule_injective
     def traverse(op):
-        """inline all one-to-one-mapping operators except the last stage (output)"""
-        if op.tag in ["nms", "invalid_to_bottom"]:
-            if op.tag == "nms":
-                sort = op.input_tensors[1]
-            else:
-                out = op.input_tensors[0]
-                sort = s[out].op.input_tensors[1]
-            score = s[sort].op.input_tensors[0]
-            fused = s[score].fuse(*s[score].op.axis)
-            num_thread = int(tvm.target.current_target(allow_none=False).max_num_threads)
-            bx, tx = s[score].split(fused, factor=num_thread)
-            s[score].bind(bx, tvm.thread_axis("blockIdx.x"))
-            s[score].bind(tx, tvm.thread_axis("threadIdx.x"))
-        if tag.is_broadcast(op.tag):
-            if op not in s.outputs:
-                s[op].compute_inline()
-            else:
-                x = op.output(0)
-                fused = s[x].fuse(*s[x].op.axis)
-                num_thread = tvm.target.current_target(allow_none=False).max_num_threads
-                bx, tx = s[x].split(fused, factor=num_thread)
-                s[x].bind(bx, tvm.thread_axis("blockIdx.x"))
-                s[x].bind(tx, tvm.thread_axis("threadIdx.x"))
-            for tensor in op.input_tensors:
-                if tensor.op.input_tensors and tensor.op not in scheduled_ops:
-                    traverse(tensor.op)
-
+        if tag.is_broadcast(op.tag) or op.tag in ['bbox_score', 'sorted_bbox']:
+            _schedule_injective(op, s)
+        for tensor in op.input_tensors:
+            if tensor.op.input_tensors and tensor.op not in scheduled_ops:
+                traverse(tensor.op)
         scheduled_ops.append(op)
-
     traverse(outs[0].op)
     return s
 
@@ -173,19 +149,7 @@ def schedule_proposal(outs):
     s: Schedule
       The computation schedule for the op.
     """
-    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
-    s = tvm.create_schedule([x.op for x in outs])
-    scheduled_ops = []
-    from .injective import _schedule_injective
-    def traverse(op):
-        if op.tag in ['bbox_score', 'sorted_bbox']:
-            _schedule_injective(op, s)
-        for tensor in op.input_tensors:
-            if tensor.op.input_tensors and tensor.op not in scheduled_ops:
-                traverse(tensor.op)
-        scheduled_ops.append(op)
-    traverse(outs[0].op)
-    return s
+    return _default_schedule(outs)
 
 @generic.schedule_get_valid_counts.register(["cuda", "gpu"])
 def schedule_get_valid_counts(outs):
@@ -203,30 +167,3 @@ def schedule_get_valid_counts(outs):
       The computation schedule for the op.
     """
     return _default_schedule(outs)
-
-@generic.schedule_argsort.register(["cuda", "gpu"])
-def schedule_argsort(outs):
-    """Schedule for argsort operator.
-
-    Parameters
-    ----------
-    outs: Array of Tensor
-        The computation graph description of argsort
-        in the format of an array of tensors.
-
-    Returns
-    -------
-    s: Schedule
-      The computation schedule for the op.
-    """
-    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
-    s = tvm.create_schedule([x.op for x in outs])
-    scheduled_ops = []
-    from .injective import _schedule_injective
-    def traverse(op):
-        for tensor in op.input_tensors:
-            if tensor.op.input_tensors and tensor.op not in scheduled_ops:
-                traverse(tensor.op)
-        scheduled_ops.append(op)
-    traverse(outs[0].op)
-    return s