Array<Tensor> inputs;
/*! \brief region of input tensors */
Array<Region> input_regions;
+ /*! \brief scalar expression inputs */
+ Array<Expr> scalar_inputs;
/*! \brief constructor */
TensorComputeOpNode() {}
// override functions
v->Visit("intrin", &intrin);
v->Visit("inputs", &inputs);
v->Visit("input_regions", &input_regions);
+ v->Visit("scalar_inputs", &scalar_inputs);
}
static Operation make(std::string name,
std::string tag,
int schedulable_ndim,
TensorIntrin intrin,
Array<Tensor> tensors,
- Array<Region> regions);
+ Array<Region> regions,
+ Array<Expr> scalar_inputs);
static constexpr const char* _type_key = "TensorComputeOp";
TVM_DECLARE_NODE_TYPE_INFO(TensorComputeOpNode, BaseComputeOpNode);
* When it is a constant, it means we can only take data in that shape.
*/
Array<Buffer> buffers;
+ /*! \brief List of scalar variables, used in body. These placeholders
+ * will be bound to expressions passed in when the TensorIntrin is called
+ * from a TensorComputeOp.
+ */
+ Array<Var> scalar_params;
/*! \brief The normal statement to execute the intrinsic */
Stmt body;
/*!
v->Visit("op", &op);
v->Visit("inputs", &inputs);
v->Visit("buffers", &buffers);
+ v->Visit("scalar_params", &scalar_params);
v->Visit("body", &body);
v->Visit("reduce_init", &reduce_init);
v->Visit("reduce_update", &reduce_update);
Operation op,
Array<Tensor> inputs,
Array<Buffer> buffers,
+ Array<Var> scalar_params,
Stmt body,
Stmt reduce_init,
Stmt reduce_update);
Array<Tensor> tensors;
/*! \brief regions of input tensors */
Array<Region> regions;
+
+
/*!
* \brief IterVar on each reduction axis, if the
* intrin will use the reduce axis
*/
Array<IterVar> reduce_axis;
+ /*! \brief scalar expression inputs */
+ Array<Expr> scalar_inputs;
+
void VisitAttrs(AttrVisitor* v) final {
v->Visit("intrin", &intrin);
v->Visit("tensors", &tensors);
v->Visit("regions", &regions);
v->Visit("reduce_axis", &reduce_axis);
+ v->Visit("scalar_inputs", &scalar_inputs);
}
static TensorIntrinCall make(TensorIntrin intrin,
Array<Tensor> tensors,
Array<Region> regions,
- Array<IterVar> reduce_axis);
+ Array<IterVar> reduce_axis,
+ Array<Expr> scalar_inputs);
static constexpr const char* _type_key = "TensorIntrinCall";
TVM_DECLARE_NODE_TYPE_INFO(TensorIntrinCallNode, Node);
out_ndim,
body.intrin,
body.tensors,
- body.regions)
+ body.regions,
+ body.scalar_inputs)
else:
if not isinstance(body, (list, tuple)):
body = [body]
decl_tensor_intrin: Construct a TensorIntrin
"""
def __call__(self, *args, **kwargs):
- tensors = [x.tensor for x in args]
- regions = [_get_region(x) for x in args]
+ tensors = [x.tensor for x in args if isinstance(x, _tensor.TensorSlice)]
+ scalar_inputs = [x for x in args if not isinstance(x, _tensor.TensorSlice)]
+ regions = [_get_region(x) for x in args if isinstance(x, _tensor.TensorSlice)]
reduce_axis = []
if "reduce_axis" in kwargs:
reduce_axis = kwargs["reduce_axis"]
if not isinstance(reduce_axis, (list, tuple)):
reduce_axis = [reduce_axis]
reduce_axis = _api.convert(reduce_axis)
- return _api_internal._TensorIntrinCall(self, tensors, regions, reduce_axis)
+ if scalar_inputs:
+ scalar_inputs = _api.convert(scalar_inputs)
+ return _api_internal._TensorIntrinCall(self, tensors, regions, reduce_axis, scalar_inputs)
def decl_tensor_intrin(op,
fcompute,
name="tensor_intrin",
- binds=None):
+ binds=None, scalar_params=None):
"""Declare a tensor intrinsic function.
Parameters
requirement of the function. By default, a new compact buffer is created
for each tensor in the argument.
+ scalar_params: a list of variables used by op, whose values will be passed
+ as scalar_inputs when the tensor intrinsic is called.
+
Returns
-------
intrin: TensorIntrin
offset_factor=cfg.offset_factor))
binds_list.append(buf)
- body = fcompute(binds_list[:len(inputs)], binds_list[len(inputs):])
+ if scalar_params:
+ body = fcompute(binds_list[:len(inputs)], binds_list[len(inputs):], scalar_params)
+ else:
+ body = fcompute(binds_list[:len(inputs)], binds_list[len(inputs):])
+ scalar_params = []
if isinstance(body, (_expr.Expr, _stmt.Stmt)):
body = [body]
body = [_make.Evaluate(x) if isinstance(x, _expr.Expr) else x for x in body]
if len(body) < 3:
body += [None] * (3 - len(body))
return _api_internal._TensorIntrin(
- name, op, inputs, binds_list, *body)
+ name, op, inputs, binds_list, scalar_params, *body)
Operation op,
Array<Tensor> inputs,
Array<Buffer> buffers,
+ Array<Var> scalar_params,
Stmt body,
Stmt reduce_init,
Stmt reduce_update) {
n->op = std::move(op);
n->inputs = std::move(inputs);
n->buffers = std::move(buffers);
+ n->scalar_params = std::move(scalar_params);
n->body = std::move(body);
n->reduce_init = std::move(reduce_init);
n->reduce_update = std::move(reduce_update);
TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin,
Array<Tensor> tensors,
Array<Region> regions,
- Array<IterVar> reduce_axis) {
+ Array<IterVar> reduce_axis,
+ Array<Expr> scalar_inputs) {
auto n = make_node<TensorIntrinCallNode>();
n->intrin = std::move(intrin);
n->tensors = std::move(tensors);
n->regions = std::move(regions);
n->reduce_axis = std::move(reduce_axis);
+ n->scalar_inputs = std::move(scalar_inputs);
return TensorIntrinCall(n);
}
int schedulable_ndim,
TensorIntrin intrin,
Array<Tensor> tensors,
- Array<Region> regions) {
+ Array<Region> regions,
+ Array<Expr> scalar_inputs) {
auto n = make_node<TensorComputeOpNode>();
n->name = std::move(name);
n->tag = std::move(tag);
n->intrin = std::move(intrin);
n->inputs = std::move(tensors);
n->input_regions = std::move(regions);
+ n->scalar_inputs = std::move(scalar_inputs);
return Operation(n);
}
std::unordered_map<const Variable*, Expr> vmap;
ir::ArgBinder binder(&vmap);
+ // Map the expressions passed in the call to the TensorIntrin, to the placeholder
+ // variables
+ Array<Expr> user_expr = this->scalar_inputs;
+ Array<Var> scalar_params = this->intrin->scalar_params;
+ Array<Expr> sp_expr;
+ for (auto sp : scalar_params) {
+ Expr esp = sp;
+ sp_expr.push_back(esp);
+ }
+ CHECK_EQ(sp_expr.size(), user_expr.size());
+ // TODO(jdavies-huawei): what name should be used here?
+ binder.BindArray(sp_expr, user_expr, this->name);
+
size_t tloc = stage->leaf_iter_vars.size();
ComputeLoopNest n = ComputeLoopNest::make(this, stage, dom_map, debug_keep_trivial_loop);
new_regions.push_back(region);
}
+ Array<Expr> new_scalar_inputs;
+ for (Expr old_input : tensor_op->scalar_inputs) {
+ new_scalar_inputs.push_back(VarReplacer(vsub2newvar).Mutate(old_input));
+ }
+
Operation cache_op = TensorComputeOpNode::make(
tensor_op->name + "." + scope, tensor_op->tag, new_axis,
tensor_op->reduce_axis, tensor_op->schedulable_ndim,
- tensor_op->intrin, tensor_op->inputs, new_regions);
+ tensor_op->intrin, tensor_op->inputs, new_regions, new_scalar_inputs);
// axis will be used in generating compute op
Array<IterVar> compute_axis = tensor_op->axis;
assert(s[z].iter_var_attrs[xi].tensor_intrin == intrin)
assert(s[z].iter_var_attrs[xi].iter_type == tvm.schedule.IterVar.Tensorized)
+def test_tensor_intrin_scalar_params():
+ n = tvm.var("n")
+ x = tvm.placeholder((n,), name='x')
+ v = tvm.var("v")
+ w = tvm.var("w")
+ z = tvm.compute((n,), lambda i: x[i]*v + w, name='z')
+
+ def intrin_func(ins, outs, sp):
+ assert(isinstance(ins[0], tvm.schedule.Buffer))
+ assert(ins[0].shape[0] == n)
+ assert(sp[0] == v)
+ assert(sp[1] == w)
+ return tvm.call_packed("hw_func", ins[0].data, outs[0].data, sp[0], sp[1])
+
+ with tvm.build_config(offset_factor=1):
+ intrin = tvm.decl_tensor_intrin(z.op, intrin_func, scalar_params=[v, w])
+ assert intrin.op == z.op
+ assert intrin.reduce_init is None
+ assert tuple(intrin.inputs) == tuple(z.op.input_tensors)
+ assert(intrin.buffers[0].shape[0] == n)
+ assert tuple(intrin.scalar_params) == tuple((v, w))
+
+ A = tvm.placeholder((10,10), name='A')
+ # Pass scalar inputs to the TensorIntrin, interleaved with tensor inputs
+ C = tvm.compute((10,10), lambda i, j: intrin(i*i, A[i, j], i+j), name="C")
+ s = tvm.create_schedule(C.op)
+ stmt = tvm.lower(s, [A, C], simple_mode=True)
+ assert isinstance(stmt.body.body.body, tvm.stmt.Evaluate)
+ assert len(stmt.body.body.body.value.args) == 5
+ assert str(stmt.body.body.body.value.args[3]) == "(i*i)"
+ assert str(stmt.body.body.body.value.args[4]) == "(i + j)"
if __name__ == "__main__":
test_singleton()
test_pragma()
test_tensor_intrin()
+ test_tensor_intrin_scalar_params()
test_rfactor()
test_schedule_create()
test_reorder()