auto s = create_schedule(out_ops);
auto softmax = outs[0];
- auto max_elem = softmax->op->InputTensors()[1];
- auto expsum = softmax->op->InputTensors()[2];
+ tvm::Tensor max_elem;
+ tvm::Tensor expsum;
+ tvm::Tensor exp;
+ bool has_exp = false;
+
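+ // softmax_output now materializes exp(x - max) as a separate stage, so the
+ // intermediate tensors sit at different input slots depending on the op tag.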
+ auto tag = softmax->op.as<ComputeOpNode>()->tag;
+ if (tag == "softmax_output") {
+ expsum = softmax->op->InputTensors()[1];
+ exp = softmax->op->InputTensors()[0];
+ max_elem = s[exp]->op->InputTensors()[1];
+ has_exp = true;
+ } else if (tag == "log_softmax_output") {
+ max_elem = softmax->op->InputTensors()[1];
+ expsum = softmax->op->InputTensors()[2];
+ } else {
+ LOG(ERROR) << "Tag is expected to be softmax_output or log_softmax_output. Got " << tag;
+ }
int num_thread = 64;
auto block_x = tvm::thread_axis(Range(), "blockIdx.x");
auto thread_x = tvm::thread_axis(Range(0, num_thread), "threadIdx.x");
+ if (has_exp) {
+ s[exp].bind(exp->op.as<ComputeOpNode>()->axis[0], block_x);
+ }
+
s[max_elem].bind(max_elem->op.as<ComputeOpNode>()->axis[0], block_x);
auto k = expsum->op.as<ComputeOpNode>()->reduce_axis[0];
auto k2 = tvm::reduce_axis(Range(0, input_shape[axis]), "k2");
auto reduced_shape = MakeReduceTargetShape({axis}, x, false, false);
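+ // Record the softmax axis on the op so schedules can read it back from attrs
+ // (the x86 schedule below does exactly that).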
+ tvm::Map<std::string, NodeRef> attrs;
+ attrs.Set("axis", Integer(axis));
+
auto insert_reduce_index = [axis, ndim](const Array<Var> &indices,
                                        const IterVar &reduce_index) {
  // Insert the reduce axis at position `axis` among the given indices.
  Array<Expr> eval_range;
  int arg_counter = 0;
  for (size_t i = 0; i < ndim; ++i) {
    if (static_cast<int>(i) == axis) eval_range.push_back(reduce_index);
    else eval_range.push_back(indices[arg_counter++]);
  }
  return eval_range;
};
+ auto get_non_reduce_indices = [axis, ndim](const Array<Var> &indices) {
+ Array<Expr> non_reduce_indices;
+ for (size_t i = 0; i < ndim; ++i) {
+ if (static_cast<int>(i) != axis)
+ non_reduce_indices.push_back(indices[i]);
+ }
+ return non_reduce_indices;
+ };
+
auto _compute_max = [&](const Array<Var> &indices) {
auto eval_range = insert_reduce_index(indices, k1);
return topi::MaxOp(x(eval_range), {k1});
};
- auto _compute_expsum = [&](const Tensor &max_elem,
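+ // Split exp(x - max) into its own stage; subtracting the row max keeps the
+ // exponent numerically stable.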
+ auto _compute_exp = [&](const Tensor &max_elem,
+ const Array<Var> &indices) {
+ auto non_reduce_indices = get_non_reduce_indices(indices);
+ return tvm::exp(x(indices) - max_elem(non_reduce_indices));
+ };
+
+ auto _compute_expsum = [&](const Tensor &exp,
const Array<Var> &indices) {
auto eval_range = insert_reduce_index(indices, k2);
- return tvm::sum(tvm::exp(x(eval_range) - max_elem(indices)), {k2});
+ return tvm::sum(exp(eval_range), {k2});
};
- auto _normalize = [&](const Tensor &max_elem, const Tensor &expsum,
+ auto _normalize = [&](const Tensor &exp, const Tensor &expsum,
const Array<Var> &indices) {
- Array<Expr> non_reduce_indices;
- for (size_t i = 0; i < ndim; ++i) {
- if (static_cast<int>(i) != axis)
- non_reduce_indices.push_back(indices[i]);
- }
- return tvm::exp(x(indices) - max_elem(non_reduce_indices)) /
- expsum(non_reduce_indices);
+ auto non_reduce_indices = get_non_reduce_indices(indices);
+ return exp(indices) / expsum(non_reduce_indices);
};
auto max_elem = tvm::compute(reduced_shape, _compute_max);
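+ // exp is computed once and reused by both the reduction and the normalization.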
+ auto exp = tvm::compute(input_shape, [&](const Array<Var> &indices) {
+ return _compute_exp(max_elem, indices);
+ });
auto expsum = tvm::compute(reduced_shape, [&](const Array<Var> &indices) {
- return _compute_expsum(max_elem, indices);
+ return _compute_expsum(exp, indices);
});
return tvm::compute(input_shape, [&](const Array<Var> &indices) {
- return _normalize(max_elem, expsum, indices);
- }, name, tag);
+ return _normalize(exp, expsum, indices);
+ }, name, tag, attrs);
}
/*!
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
softmax = outs[0]
- max_elem = softmax.op.input_tensors[1]
- expsum = softmax.op.input_tensors[2]
+
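+ # For softmax_output the exp(x - max) stage is input 0; log_softmax keeps the
+ # exponentiation inline, so exp stays None there.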
+ op_tag = softmax.op.tag
+ if op_tag == 'softmax_output':
+ expsum = softmax.op.input_tensors[1]
+ exp = softmax.op.input_tensors[0]
+ max_elem = s[exp].op.input_tensors[1]
+ elif op_tag == 'log_softmax_output':
+ exp = None
+ max_elem = softmax.op.input_tensors[1]
+ expsum = softmax.op.input_tensors[2]
+ else:
+ raise ValueError('Tag is expected to be softmax_output or log_softmax_output. \
+ Got {0}'.format(op_tag))
if len(softmax.shape) > 2:
- for op in [max_elem.op, expsum.op, softmax.op]:
+ ops = [max_elem.op, expsum.op, softmax.op]
+ if exp is not None:
+ ops.append(exp.op)
+
+ for op in ops:
s = _schedule_injective(op, s)
else:
num_thread = 64
block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
+ if exp is not None:
+ s[exp].bind(exp.op.axis[0], block_x)
+
s[max_elem].bind(max_elem.op.axis[0], block_x)
k = expsum.op.reduce_axis[0]
ko, ki = s[expsum].split(k, factor=num_thread)
tvm.schedule.AutoInlineInjective(s)
softmax = outs[0]
- max_elem = softmax.op.input_tensors[1]
- expsum = softmax.op.input_tensors[2]
+
+ op_tag = softmax.op.tag
+ if op_tag == 'softmax_output':
+ expsum = softmax.op.input_tensors[1]
+ exp = softmax.op.input_tensors[0]
+ max_elem = s[exp].op.input_tensors[1]
+ elif op_tag == 'log_softmax_output':
+ exp = None
+ max_elem = softmax.op.input_tensors[1]
+ expsum = softmax.op.input_tensors[2]
+ else:
+ raise ValueError('Tag is expected to be softmax_output or log_softmax_output. \
+ Got {0}'.format(op_tag))
+
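+ # Keep every intermediate stage inside the softmax loop nest via compute_at.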
+ if exp is not None:
+ s[exp].compute_at(s[softmax], s[softmax].op.axis[1])
s[expsum].compute_at(s[softmax], s[softmax].op.axis[1])
s[max_elem].compute_at(s[softmax], s[softmax].op.axis[1])
def insert_reduce_index(indices, reduce_index):
return indices[:axis] + (reduce_index,) + indices[axis:]
+ def get_non_reduce_indices(indices):
+ return tuple([var for (i, var) in enumerate(indices) if i != axis])
+
def _compute_max(*indices):
eval_range = insert_reduce_index(indices, k1)
return tvm.max(x[eval_range], axis=k1)
- def _compute_expsum(max_elem, *indices):
+ def _compute_exp(max_elem, *indices):
+ non_reduce_indices = get_non_reduce_indices(indices)
+ return tvm.exp(x[indices] - max_elem[non_reduce_indices])
+
+ def _compute_expsum(exp, *indices):
eval_range = insert_reduce_index(indices, k2)
- return tvm.sum(tvm.exp(x[eval_range] - max_elem[indices]), axis=k2)
+ return tvm.sum(exp[eval_range], axis=k2)
- def _normalize(max_elem, expsum, *indices):
- non_reduce_indices = tuple([var for (i, var) in enumerate(indices) if i != axis])
- return tvm.exp(x[indices] - max_elem[non_reduce_indices]) / expsum[non_reduce_indices]
+ def _normalize(exp, expsum, *indices):
+ non_reduce_indices = get_non_reduce_indices(indices)
+ return exp[indices] / expsum[non_reduce_indices]
reduced_shape = tuple([dim for (i, dim) in enumerate(shape) if i != axis])
max_elem = tvm.compute(reduced_shape, _compute_max, name='T_softmax_maxelem')
- expsum = tvm.compute(reduced_shape, lambda *indices: _compute_expsum(max_elem, *indices),
+ exp = tvm.compute(shape, lambda *indices: _compute_exp(max_elem, *indices),
+ name='T_softmax_exp')
+ expsum = tvm.compute(reduced_shape, lambda *indices: _compute_expsum(exp, *indices),
name='T_softmax_expsum')
- return tvm.compute(shape, lambda *indices: _normalize(max_elem, expsum, *indices),
- name='T_softmax_norm')
+ return tvm.compute(shape, lambda *indices: _normalize(exp, expsum, *indices),
+ name='T_softmax_norm', attrs={"axis" : axis})
@tvm.tag_scope(tag='log_softmax_output')
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
softmax = outs[0]
- max_elem = softmax.op.input_tensors[1]
- expsum = softmax.op.input_tensors[2]
+
+ op_tag = softmax.op.tag
+ if op_tag == 'softmax_output':
+ expsum = softmax.op.input_tensors[1]
+ exp = softmax.op.input_tensors[0]
+ max_elem = s[exp].op.input_tensors[1]
+ elif op_tag == 'log_softmax_output':
+ exp = None
+ max_elem = softmax.op.input_tensors[1]
+ expsum = softmax.op.input_tensors[2]
+ else:
+ raise ValueError('Tag is expected to be softmax_output or log_softmax_output. \
+ Got {0}'.format(op_tag))
+
+ if exp is not None:
+ s[exp].opengl()
+
s[max_elem].opengl()
s[expsum].opengl()
s[softmax].opengl()
The computation schedule for the op.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
- x = outs[0]
+ softmax = outs[0]
s = tvm.create_schedule([x.op for x in outs])
- tvm.schedule.AutoInlineInjective(s)
- if len(s[x].op.axis) >= 5:
- fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1], s[x].op.axis[2])
- s[x].parallel(fused)
- elif len(s[x].op.axis) >= 3:
- fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1])
- s[x].parallel(fused)
+
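+ # Only softmax_output carries the axis attribute; log_softmax is always
+ # computed over axis 1.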
+ op_tag = softmax.op.tag
+ if op_tag == 'softmax_output':
+ exp = softmax.op.input_tensors[0]
+ expsum = softmax.op.input_tensors[1]
+ max_elem = s[exp].op.input_tensors[1]
+ axis = int(softmax.op.attrs['axis'])
+ elif op_tag == 'log_softmax_output':
+ exp = None
+ max_elem = softmax.op.input_tensors[1]
+ expsum = softmax.op.input_tensors[2]
+ axis = 1
else:
- s[x].parallel(s[x].op.axis[0])
+ raise ValueError('Tag is expected to be softmax_output or log_softmax_output. \
+ Got {0}'.format(op_tag))
+
+ # only parallelize outer dimensions up to axis
+ outer_axes = [s[softmax].op.axis[i] for i in range(0, axis)]
+ fused_outer_axes = s[softmax].fuse(*outer_axes)
+ s[softmax].parallel(fused_outer_axes)
+
+ # move computations with the same outer dimensions under the same root
+ s[max_elem].compute_at(s[softmax], fused_outer_axes)
+ s[expsum].compute_at(s[softmax], fused_outer_axes)
+
+ if exp is not None:
+ s[exp].compute_at(s[softmax], fused_outer_axes)
+
return s