[TOPI] Update softmax compute and CPU schedule (#3680)
authorJon Soifer <soiferj@gmail.com>
Mon, 5 Aug 2019 02:46:28 +0000 (19:46 -0700)
committerYao Wang <kevinthesunwy@gmail.com>
Mon, 5 Aug 2019 02:46:28 +0000 (19:46 -0700)
* Update Softmax compute and CPU schedule

* Add C++ compute

* Fix schedule

* Update CUDA and OpenGL schedules

* Fix log_softmax

* Fix hls and opengl schedules

* Fix CUDA schedule

topi/include/topi/cuda/softmax.h
topi/include/topi/nn/softmax.h
topi/python/topi/cuda/softmax.py
topi/python/topi/hls/nn.py
topi/python/topi/nn/softmax.py
topi/python/topi/opengl/softmax.py
topi/python/topi/x86/nn.py

index ee27476..33be899 100644 (file)
@@ -50,13 +50,32 @@ inline Schedule schedule_softmax(const Target &target, const Array<Tensor>& outs
   auto s = create_schedule(out_ops);
 
   auto softmax = outs[0];
-  auto max_elem = softmax->op->InputTensors()[1];
-  auto expsum = softmax->op->InputTensors()[2];
+  tvm::Tensor max_elem;
+  tvm::Tensor expsum;
+  tvm::Tensor exp;
+  bool has_exp = false;
+
+  auto tag = softmax->op.as<ComputeOpNode>()->tag;
+  if (tag == "softmax_output") {
+    expsum = softmax->op->InputTensors()[1];
+    exp = softmax->op->InputTensors()[0];
+    max_elem = s[exp]->op->InputTensors()[1];
+    has_exp = true;
+  } else if (tag == "log_softmax_output") {
+    max_elem = softmax->op->InputTensors()[1];
+    expsum = softmax->op->InputTensors()[2];
+  } else {
+    LOG(ERROR) << "Tag is expected to be softmax_output or log_softmax_output. Got " << tag;
+  }
 
   int num_thread = 64;
   auto block_x = tvm::thread_axis(Range(), "blockIdx.x");
   auto thread_x = tvm::thread_axis(Range(0, num_thread), "threadIdx.x");
 
+  if (has_exp) {
+    s[exp].bind(exp->op.as<ComputeOpNode>()->axis[0], block_x);
+  }
+
   s[max_elem].bind(max_elem->op.as<ComputeOpNode>()->axis[0], block_x);
 
   auto k = expsum->op.as<ComputeOpNode>()->reduce_axis[0];
index a9ad1d7..8aa388a 100644 (file)
@@ -62,6 +62,9 @@ inline Tensor softmax(const Tensor &x,
   auto k2 = tvm::reduce_axis(Range(0, input_shape[axis]), "k2");
   auto reduced_shape = MakeReduceTargetShape({axis}, x, false, false);
 
+  tvm::Map<std::string, NodeRef> attrs;
+  attrs.Set("axis", Integer(axis));
+
   auto insert_reduce_index = [axis, ndim](const Array<Var> &indices,
                                           const IterVar &reduce_index) {
     Array<Expr> eval_range;
@@ -75,35 +78,48 @@ inline Tensor softmax(const Tensor &x,
     return eval_range;
   };
 
+  auto get_non_reduce_indices = [axis, ndim](const Array<Var> &indices) {
+    Array<Expr> non_reduce_indices;
+    for (size_t i = 0; i < ndim; ++i) {
+      if (static_cast<int>(i) != axis)
+        non_reduce_indices.push_back(indices[i]);
+    }
+    return non_reduce_indices;
+  };
+
   auto _compute_max = [&](const Array<Var> &indices) {
     auto eval_range = insert_reduce_index(indices, k1);
     return topi::MaxOp(x(eval_range), {k1});
   };
 
-  auto _compute_expsum = [&](const Tensor &max_elem,
+  auto _compute_exp = [&](const Tensor &max_elem,
+                          const Array<Var> &indices) {
+    auto non_reduce_indices = get_non_reduce_indices(indices);
+    return tvm::exp(x(indices) - max_elem(non_reduce_indices));
+  };
+
+  auto _compute_expsum = [&](const Tensor &exp,
                              const Array<Var> &indices) {
     auto eval_range = insert_reduce_index(indices, k2);
-    return tvm::sum(tvm::exp(x(eval_range) - max_elem(indices)), {k2});
+    return tvm::sum(exp(eval_range), {k2});
   };
 
-  auto _normalize = [&](const Tensor &max_elem, const Tensor &expsum,
+  auto _normalize = [&](const Tensor &exp, const Tensor &expsum,
                         const Array<Var> &indices) {
-    Array<Expr> non_reduce_indices;
-    for (size_t i = 0; i < ndim; ++i) {
-      if (static_cast<int>(i) != axis)
-        non_reduce_indices.push_back(indices[i]);
-    }
-    return tvm::exp(x(indices) - max_elem(non_reduce_indices)) /
-           expsum(non_reduce_indices);
+    auto non_reduce_indices = get_non_reduce_indices(indices);
+    return exp(indices) / expsum(non_reduce_indices);
   };
 
   auto max_elem = tvm::compute(reduced_shape, _compute_max);
+  auto exp = tvm::compute(input_shape, [&](const Array<Var> &indices) {
+      return _compute_exp(max_elem, indices);
+  });
   auto expsum = tvm::compute(reduced_shape, [&](const Array<Var> &indices) {
-      return _compute_expsum(max_elem, indices);
+      return _compute_expsum(exp, indices);
   });
   return tvm::compute(input_shape, [&](const Array<Var> &indices) {
-      return _normalize(max_elem, expsum, indices);
-  }, name, tag);
+      return _normalize(exp, expsum, indices);
+  }, name, tag, attrs);
 }
 
 /*!
index bb885fc..09b6ef8 100644 (file)
@@ -38,17 +38,35 @@ def schedule_softmax(outs):
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
     softmax = outs[0]
-    max_elem = softmax.op.input_tensors[1]
-    expsum = softmax.op.input_tensors[2]
+
+    op_tag = softmax.op.tag
+    if op_tag == 'softmax_output':
+        expsum = softmax.op.input_tensors[1]
+        exp = softmax.op.input_tensors[0]
+        max_elem = s[exp].op.input_tensors[1]
+    elif op_tag == 'log_softmax_output':
+        exp = None
+        max_elem = softmax.op.input_tensors[1]
+        expsum = softmax.op.input_tensors[2]
+    else:
+        raise ValueError('Tag is expected to be softmax_output or log_softmax_output. \
+                         Got {0}'.format(op_tag))
 
     if len(softmax.shape) > 2:
-        for op in [max_elem.op, expsum.op, softmax.op]:
+        ops = [max_elem.op, expsum.op, softmax.op]
+        if exp != None:
+            ops.append(exp.op)
+            
+        for op in ops:
             s = _schedule_injective(op, s)
     else:
         num_thread = 64
         block_x = tvm.thread_axis("blockIdx.x")
         thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
 
+        if exp != None:
+            s[exp].bind(exp.op.axis[0], block_x)
+
         s[max_elem].bind(max_elem.op.axis[0], block_x)
         k = expsum.op.reduce_axis[0]
         ko, ki = s[expsum].split(k, factor=num_thread)
index 0658b02..8cc23ca 100644 (file)
@@ -261,8 +261,22 @@ def schedule_softmax(outs):
     tvm.schedule.AutoInlineInjective(s)
 
     softmax = outs[0]
-    max_elem = softmax.op.input_tensors[1]
-    expsum = softmax.op.input_tensors[2]
+
+    op_tag = softmax.op.tag
+    if op_tag == 'softmax_output':
+        expsum = softmax.op.input_tensors[1]
+        exp = softmax.op.input_tensors[0]
+        max_elem = s[exp].op.input_tensors[1]
+    elif op_tag == 'log_softmax_output':
+        exp = None
+        max_elem = softmax.op.input_tensors[1]
+        expsum = softmax.op.input_tensors[2]
+    else:
+        raise ValueError('Tag is expected to be softmax_output or log_softmax_output. \
+                         Got {0}'.format(op_tag))
+
+    if exp != None:
+        s[exp].compute_at(s[softmax], s[softmax].op.axis[1])
 
     s[expsum].compute_at(s[softmax], s[softmax].op.axis[1])
     s[max_elem].compute_at(s[softmax], s[softmax].op.axis[1])
index 00bbe55..16ffd79 100644 (file)
@@ -48,24 +48,33 @@ def softmax(x, axis=-1):
     def insert_reduce_index(indices, reduce_index):
         return indices[:axis] + (reduce_index,) + indices[axis:]
 
+    def get_non_reduce_indices(indices):
+        return tuple([var for (i, var) in enumerate(indices) if i != axis])
+
     def _compute_max(*indices):
         eval_range = insert_reduce_index(indices, k1)
         return tvm.max(x[eval_range], axis=k1)
 
-    def _compute_expsum(max_elem, *indices):
+    def _compute_exp(max_elem, *indices):
+        non_reduce_indices = get_non_reduce_indices(indices)
+        return tvm.exp(x[indices] - max_elem[non_reduce_indices])
+
+    def _compute_expsum(exp, *indices):
         eval_range = insert_reduce_index(indices, k2)
-        return tvm.sum(tvm.exp(x[eval_range] - max_elem[indices]), axis=k2)
+        return tvm.sum(exp[eval_range], axis=k2)
 
-    def _normalize(max_elem, expsum, *indices):
-        non_reduce_indices = tuple([var for (i, var) in enumerate(indices) if i != axis])
-        return tvm.exp(x[indices] - max_elem[non_reduce_indices]) / expsum[non_reduce_indices]
+    def _normalize(exp, expsum, *indices):
+        non_reduce_indices = get_non_reduce_indices(indices)
+        return exp[indices] / expsum[non_reduce_indices]
 
     reduced_shape = tuple([dim for (i, dim) in enumerate(shape) if i != axis])
     max_elem = tvm.compute(reduced_shape, _compute_max, name='T_softmax_maxelem')
-    expsum = tvm.compute(reduced_shape, lambda *indices: _compute_expsum(max_elem, *indices),
+    exp = tvm.compute(shape, lambda *indices: _compute_exp(max_elem, *indices),
+                      name='T_softmax_exp')
+    expsum = tvm.compute(reduced_shape, lambda *indices: _compute_expsum(exp, *indices),
                          name='T_softmax_expsum')
-    return tvm.compute(shape, lambda *indices: _normalize(max_elem, expsum, *indices),
-                       name='T_softmax_norm')
+    return tvm.compute(shape, lambda *indices: _normalize(exp, expsum, *indices),
+                       name='T_softmax_norm', attrs={"axis" : axis})
 
 
 @tvm.tag_scope(tag='log_softmax_output')
index b9d9c29..96218e0 100644 (file)
@@ -37,8 +37,23 @@ def schedule_softmax(outs):
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
     s = tvm.create_schedule([x.op for x in outs])
     softmax = outs[0]
-    max_elem = softmax.op.input_tensors[1]
-    expsum = softmax.op.input_tensors[2]
+
+    op_tag = softmax.op.tag
+    if op_tag == 'softmax_output':
+        expsum = softmax.op.input_tensors[1]
+        exp = softmax.op.input_tensors[0]
+        max_elem = s[exp].op.input_tensors[1]
+    elif op_tag == 'log_softmax_output':
+        exp = None
+        max_elem = softmax.op.input_tensors[1]
+        expsum = softmax.op.input_tensors[2]
+    else:
+        raise ValueError('Tag is expected to be softmax_output or log_softmax_output. \
+                         Got {0}'.format(op_tag))
+
+    if exp != None:
+        s[exp].opengl()
+
     s[max_elem].opengl()
     s[expsum].opengl()
     s[softmax].opengl()
index 445bc69..8e506da 100644 (file)
@@ -36,15 +36,34 @@ def schedule_softmax(outs):
         The computation schedule for the op.
     """
     outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
-    x = outs[0]
+    softmax = outs[0]
     s = tvm.create_schedule([x.op for x in outs])
-    tvm.schedule.AutoInlineInjective(s)
-    if len(s[x].op.axis) >= 5:
-        fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1], s[x].op.axis[2])
-        s[x].parallel(fused)
-    elif len(s[x].op.axis) >= 3:
-        fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1])
-        s[x].parallel(fused)
+
+    op_tag = softmax.op.tag
+    if op_tag == 'softmax_output':
+        exp = softmax.op.input_tensors[0]
+        expsum = softmax.op.input_tensors[1]
+        max_elem = s[exp].op.input_tensors[1]
+        axis = int(softmax.op.attrs['axis'])
+    elif op_tag == 'log_softmax_output':
+        exp = None
+        max_elem = softmax.op.input_tensors[1]
+        expsum = softmax.op.input_tensors[2]
+        axis = 1
     else:
-        s[x].parallel(s[x].op.axis[0])
+        raise ValueError('Tag is expected to be softmax_output or log_softmax_output. \
+                         Got {0}'.format(op_tag))
+
+    # only parallelize outer dimensions up to axis
+    outer_axes = [s[softmax].op.axis[i] for i in range(0, axis)]
+    fused_outer_axes = s[softmax].fuse(*outer_axes)
+    s[softmax].parallel(fused_outer_axes)
+
+    # move computations with the same outer dimensions under the same root
+    s[max_elem].compute_at(s[softmax], fused_outer_axes)
+    s[expsum].compute_at(s[softmax], fused_outer_axes)
+
+    if exp != None:
+        s[exp].compute_at(s[softmax], fused_outer_axes)
+
     return s