class Any : public ExprNode {
public:
void VisitAttrs(AttrVisitor* v) final {}
+ /*!
+  * \brief Convert to var.
+  * \return The value produced by Variable::make(Int(32), "any_dim"); the call
+  *         is issued again on every invocation of ToVar().
+  */
+ Var ToVar() const {
+ return Variable::make(Int(32), "any_dim");
+ }
TVM_DLL static Expr make();
using TNonComputational = bool;
/*!
+ * \brief Mark whether the operator's output shape is data dependent.
+ */
+using TShapeDataDependant = bool;
+
+/*!
* \brief Computation description interface.
*
* \note This function has a special convention
using FShapeFunc = runtime::TypedPackedFunc<
Array<Tensor>(const Attrs& attrs,
const Array<Tensor>& inputs,
- const Array<Shape>& out_shapes)>;
+ const Array<IndexExpr>& out_ndims)>;
} // namespace relay
} // namespace tvm
from enum import Enum
-from .util import _internal_assert, _apply_indices
+from .util import _internal_assert
from . import calls
from . import util
from .preprocessor import determine_variable_usage
from ..tensor import Tensor, Operation
from .. import _api_internal as _tvm_internal
from .. import expr as _expr
-from .. import stmt as _stmt
from .. import make as _make
from .. import api as _api
from .. import ir_pass as _ir_pass
def concat_list_to_block(lst):
"""Concatenate a list of Python IR nodes to HalideIR Block"""
+ if not lst:
+ return util.make_nop()
n = len(lst)
if n == 1:
return lst[0]
body = lst[n - 1]
for i in range(1, n):
stmt = lst[n - 1 - i]
- if isinstance(stmt, _stmt.AssertStmt):
- body = _make.AssertStmt(stmt.condition, stmt.message, body)
- else:
- body = _make.Block(stmt, body)
+ body = _make.Block(stmt, body)
return body
ast.LtE : operator.le,
ast.Eq : operator.eq,
ast.NotEq : operator.ne,
- ast.And : _all,
- ast.Or : _any,
+ ast.And : _all,
+ ast.Or : _any,
}
to_pop = []
for key, val in self.usage.items():
_, level, _ = val
+ if key not in self.symbols:
+ # don't realize the symbols that are never visited
+ continue
if level != node:
continue
_internal_assert(key in self.symbols.keys(), "Unknown symbol %s!" % key)
def visit_Attribute(self, node):
- _internal_assert(isinstance(node.value, ast.Name), \
- "For atrribute access, only both names are supported so far!")
buf = self.visit(node.value)
return getattr(buf, node.attr)
def visit_Subscript(self, node):
args = self.visit(node.slice)
- if isinstance(node.value, ast.Name):
- if node.value.id in self.closure_vars:
- args = ast.literal_eval(str(args))
- return _api.convert(_apply_indices(self.closure_vars[node.value.id], args))
-
- buf = self.visit(node.value)
- if isinstance(buf, Array):
- for i in args:
- if isinstance(i, numbers.Integral):
- buf = buf[i]
- else:
- _internal_assert(isinstance(i, (_expr.IntImm, _expr.UIntImm)), \
- "All indices are supposed to be constants")
- buf = buf[i.value]
-
- return buf
-
- if isinstance(node.ctx, ast.Load):
- return _make.Call(buf.dtype, buf.name, args, \
- _expr.Call.Halide, buf.op, buf.value_index)
-
- return buf, args
-
- shape = self.visit(node.value)
- _internal_assert(len(args) == 1, "For 'shape' access the argument should be only one!")
- args = args[0]
- #TODO: maybe support non-constant value later?
- _internal_assert(isinstance(args, (_expr.IntImm, _expr.UIntImm)), \
- "So far only constant shape access supported!")
- return shape[args.value]
-
+ arr = self.visit(node.value)
+ if isinstance(arr, Array):
+ for i in args:
+ if isinstance(i, numbers.Integral):
+ arr = arr[i]
+ else:
+ _internal_assert(isinstance(i, (_expr.IntImm, _expr.UIntImm)), \
+ "All indices are supposed to be constants")
+ arr = arr[i.value]
+ return arr
+ if isinstance(node.ctx, ast.Load):
+ return _make.Call(arr.dtype, arr.name, args,
+ _expr.Call.Halide, arr.op, arr.value_index)
+ return arr, args
def visit_With(self, node):
if sys.version_info[0] < 3:
def visit_If(self, node):
- cond = self.visit(node.test)
+ cond = _ir_pass.CanonicalSimplify(self.visit(node.test))
# Return no IfThenElse if proven
if isinstance(cond, _expr.UIntImm):
_name = node.target.id
if isinstance(for_type, tuple):
- low = _ir_pass.Simplify(low)
- ext = _ir_pass.Simplify(ext)
+ low = _ir_pass.CanonicalSimplify(low)
+ ext = _ir_pass.CanonicalSimplify(ext)
_internal_assert(isinstance(low, _expr.ConstExpr) and
isinstance(ext, _expr.ConstExpr), \
- "Const range should start from a const" + \
+ "Const range should start from a const " + \
"and iterate const times")
low, ext = low.value, ext.value
_internal_assert(isinstance(elem, np_arg_types), \
"Expect a numpy type but %s get!" % str(type(elem)))
return False
-
-def _apply_indices(value, indices):
- """Apply multidimensional index"""
- if indices:
- return _apply_indices(value[indices[0]], indices[1:])
- return value
The VM runtime.
"""
target = _update_target(target)
+ target_host = None if target_host == "" else target_host
+ if not target_host:
+ target_host = "llvm" if tvm.module.enabled("llvm") else "stackvm"
+ target_host = tvm.target.create(target_host)
self._compile(mod, target, target_host)
return VirtualMachine(self._get_vm())
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#pylint: disable=invalid-name, unused-argument
+#pylint: disable=invalid-name, unused-argument, len-as-condition
"""Backend compiler related feature registration"""
from __future__ import absolute_import
import topi
-from .op import register_compute, register_schedule, register_pattern
+from .op import register_compute, register_schedule, register_pattern, register_shape_func
from .op import schedule_injective, OpPattern
+from ...hybrid import script
schedule_broadcast = schedule_injective
schedule_elemwise = schedule_injective
return [topi.clip(inputs[0], attrs.a_min, attrs.a_max)]
register_schedule("clip", schedule_elemwise)
+
+# shape func: hybrid-script kernel that writes the broadcast of the two inputs x and y into an int64 tensor of length ndim
+@script
+def _broadcast_shape_func(x, y, ndim):
+ out = output_tensor((ndim,), "int64")
+ if len(x.shape) == 0:  # x is 0-d, so every output entry is taken from y
+ for i in const_range(ndim):
+ out[i] = y[i]
+ elif len(y.shape) == 0:  # y is 0-d, so every output entry is taken from x
+ for i in const_range(ndim):
+ out[i] = x[i]
+ else:
+ ndim1 = x.shape[0]
+ ndim2 = y.shape[0]
+ for i in const_range(1, min(ndim1, ndim2)+1):  # walk the overlapping entries from the last one backwards
+ if x[ndim1-i] == y[ndim2-i]:
+ out[ndim-i] = x[ndim1-i]
+ elif x[ndim1-i] == 1:  # x entry is 1, so the y entry wins
+ out[ndim-i] = y[ndim2-i]
+ else:
+ assert y[ndim2 - i] == 1, "Incompatible broadcast type %s and %s" % (
+ x[ndim1-i], y[ndim2-i])
+ out[ndim-i] = x[ndim1-i]
+ for i in const_range(min(ndim1, ndim2)+1, ndim+1):  # remaining leading entries come from whichever input is longer
+ if ndim1 >= ndim2:
+ out[ndim-i] = x[ndim1-i]
+ else:
+ out[ndim-i] = y[ndim2-i]
+ return out
+
+def broadcast_shape_func(attrs, inputs, out_ndims):  # shape-func entry point for the binary ops registered below; attrs is not read
+ return [_broadcast_shape_func(*inputs, out_ndims[0])]  # forwards every input plus the first output rank; result is a one-element list
+
+register_shape_func("add", False, broadcast_shape_func)
+register_shape_func("subtract", False, broadcast_shape_func)
+register_shape_func("multiply", False, broadcast_shape_func)
+register_shape_func("divide", False, broadcast_shape_func)
+register_shape_func("mod", False, broadcast_shape_func)
+register_shape_func("logical_and", False, broadcast_shape_func)
+register_shape_func("logical_or", False, broadcast_shape_func)
+register_shape_func("equal", False, broadcast_shape_func)
+register_shape_func("not_equal", False, broadcast_shape_func)
+register_shape_func("less", False, broadcast_shape_func)
+register_shape_func("less_equal", False, broadcast_shape_func)
+register_shape_func("greater", False, broadcast_shape_func)
+register_shape_func("greater_equal", False, broadcast_shape_func)
# specific language governing permissions and limitations
# under the License.
"""Backend compiler related feature registration"""
-# pylint: disable=invalid-name,unused-argument
+# pylint: disable=invalid-name,unused-argument, len-as-condition
from __future__ import absolute_import
+from topi.util import get_const_int, get_const_tuple
from . import op as _reg
from ._reduce import _schedule_reduce
from .op import OpPattern
+from ...hybrid import script
+from ...api import convert
schedule_injective = _reg.schedule_injective
schedule_broadcast = _reg.schedule_injective
# layout_transform
_reg.register_schedule("layout_transform", schedule_injective)
_reg.register_pattern("layout_transform", OpPattern.INJECTIVE)
+
+# shape func: hybrid-script kernel producing the one-element int64 output shape of arange from start, stop and step
+@script
+def _arange_shape_func(start, stop, step):
+ out = output_tensor((1,), "int64")
+ out[0] = int64(ceil_div((float32(stop[0]) - float32(start[0])), float32(step[0])))  # operands are cast to float32 before ceil_div; the result is cast to int64
+ return out
+
+@_reg.register_shape_func("arange", True)  # registered with data_dependant=True
+def arange_shape_func(attrs, inputs, _):  # attrs and the third argument are not read
+ return [_arange_shape_func(*inputs)]  # all inputs are forwarded positionally as start, stop, step
+
+@script
+def _concatenate_shape_func(inputs, axis):  # hybrid-script kernel: builds the output shape from a list of input shape tensors and a concat axis
+ ndim = inputs[0].shape[0]  # output length is taken from the first input
+ out = output_tensor((ndim,), "int64")
+ for i in const_range(ndim):
+ if i != axis:  # non-concat entry: copied from the first input and asserted equal across the others
+ out[i] = inputs[0][i]
+ for j in const_range(1, len(inputs)):
+ assert out[i] == inputs[j][i], \
+ "Dims mismatch in the inputs of concatenate."
+ else:  # concat entry: sum of that entry over all inputs; only reached when axis equals one of 0..ndim-1
+ out[i] = int64(0)
+ for j in const_range(len(inputs)):
+ out[i] += inputs[j][i]
+ return out
+
+@_reg.register_shape_func("concatenate", False)
+def concatenate_shape_func(attrs, inputs, _):
+ axis = get_const_int(attrs.axis)
+ return [_concatenate_shape_func(inputs, convert(axis))]
+
+@script
+def _reshape_shape_func(data_shape, newshape, ndim):  # hybrid-script kernel: resolves the entries of newshape against data_shape into an int64 tensor of length ndim
+ out = output_tensor((ndim,), "int64")
+ src_idx = 0  # next entry of data_shape to consume
+ dst_idx = 0  # next entry of out to write
+ infer_idx = -1  # stays negative until a -1 entry is seen
+ copy = False  # set by a -2 entry
+ skip = 0  # number of upcoming newshape entries already consumed by a -4
+ for i in const_range(len(newshape)):
+ if skip > 0:
+ skip -= 1
+ elif newshape[i] > 0:  # positive entry: used as-is
+ out[dst_idx] = int64(newshape[i])
+ src_idx += 1
+ dst_idx += 1
+ elif newshape[i] == 0:  # 0: copy the matching input entry
+ out[dst_idx] = data_shape[src_idx]
+ src_idx += 1
+ dst_idx += 1
+ elif newshape[i] == -1:  # -1: write a placeholder 1 now, resolved after the loop
+ assert infer_idx < 0, "One and only one dim can be inferred"
+ out[dst_idx] = int64(1)
+ infer_idx = i  # NOTE(review): stores the newshape index i, yet it is later used to index out; i and dst_idx differ after a -2 or -4 entry - confirm
+ dst_idx += 1
+ elif newshape[i] == -2:  # -2: request copying the remaining input entries after the loop
+ copy = True
+ elif newshape[i] == -3:  # -3: merge two consecutive input entries into their product
+ assert data_shape.shape[0] - src_idx > 1, \
+ "Not enough dims in input shape for -3"
+ out[dst_idx] = data_shape[src_idx] * data_shape[src_idx+1]
+ src_idx += 2
+ dst_idx += 1
+ elif newshape[i] == -4:  # -4: split one input entry in two, using the next two newshape entries
+ assert len(newshape) - i > 2, "Not enough dims in new shape for -4"
+ if newshape[i+1] == -1:
+ assert newshape[i+2] != -1, "Split dims cannot both be -1."
+ out[dst_idx] = data_shape[src_idx] / int64(newshape[i+2])
+ out[dst_idx+1] = int64(newshape[i+2])
+ else:
+ out[dst_idx] = int64(newshape[i+1])
+ if newshape[i+2] == -1:
+ out[dst_idx+1] = data_shape[src_idx] / int64(newshape[i+1])
+ else:
+ out[dst_idx+1] = int64(newshape[i+2])
+ assert data_shape[src_idx] == out[dst_idx] * out[dst_idx+1],\
+ "Product of split dims doesn't match to input dim"
+ src_idx += 1
+ dst_idx += 2
+ skip = 2  # the two entries following -4 were consumed above
+ else:
+ assert False, "Invalid special values in new shape"
+ if len(data_shape.shape) > 0:
+ # if data is not constant, we can then handle -1 and -2
+ if copy:
+ for i in range(src_idx, data_shape.shape[0]):
+ out[dst_idx] = data_shape[i]
+ dst_idx += 1
+ if infer_idx >= 0:
+ old_size = int64(1)
+ for i in const_range(data_shape.shape[0]):
+ old_size *= data_shape[i]
+ new_size = int64(1)
+ for i in const_range(out.shape[0]):  # includes the placeholder 1 written for the -1 entry
+ new_size *= out[i]
+ out[infer_idx] = old_size / new_size
+ return out
+
+@_reg.register_shape_func("reshape", False)  # registered with data_dependant=False
+def reshape_shape_func(attrs, inputs, out_ndims):  # shape-func entry point for reshape
+ newshape = get_const_tuple(attrs.newshape)  # newshape is read from the attributes as constants
+ return [_reshape_shape_func(inputs[0], convert(newshape), out_ndims[0])]  # only the first input and the first output rank are used
+
+@script
+def _take_no_axis_shape_func(indices_shape, out_ndim):  # hybrid-script kernel used by take_shape_func when attrs.axis is None
+ out = output_tensor((out_ndim,), "int64")
+ for i in const_range(out_ndim):  # output is the first out_ndim entries of indices_shape
+ out[i] = indices_shape[i]
+ return out
+
+@script
+def _take_with_axis_shape_func(data_shape, indices_shape, axis, out_ndim):  # hybrid-script kernel used by take_shape_func when an axis is given
+ out = output_tensor((out_ndim,), "int64")
+ for i in const_range(axis):  # entries before axis are copied from data_shape
+ out[i] = data_shape[i]
+ if len(indices_shape.shape) == 0:
+ # indices is constant
+ for i in const_range(axis+1, len(data_shape)):  # entry at axis is dropped; later entries move down by one
+ out[i-1] = data_shape[i]
+ else:
+ for i in const_range(len(indices_shape)):  # entries of indices_shape are written starting at position axis
+ out[axis+i] = indices_shape[i]
+ for i in const_range(axis+1, len(data_shape)):  # data_shape entries after axis follow the inserted ones
+ out[len(indices_shape)+i-1] = data_shape[i]
+ return out
+
+@_reg.register_shape_func("take", False)  # registered with data_dependant=False
+def take_shape_func(attrs, inputs, out_ndims):
+ """
+ Shape function for take op.
+ """
+ if attrs.axis is None:  # no axis: only the second input is passed on
+ return [_take_no_axis_shape_func(inputs[1], out_ndims[0])]
+ else:
+ axis = get_const_int(attrs.axis)
+ data_ndim = int(inputs[0].shape[0])
+ if axis < 0:  # a negative axis is offset by data_ndim
+ axis += data_ndim
+ assert 0 <= axis < data_ndim
+ return [_take_with_axis_shape_func(*inputs, convert(axis), out_ndims[0])]
"""
return _OpGetAttr(self, attr_name)
+ def set_attr(self, attr_name, value, plevel=10):
+ """Set attribute about the operator.
+
+ Parameters
+ ----------
+ attr_name : str
+ The attribute name
+
+ value : object
+ The attribute value
+
+ plevel : int
+ The priority level
+ """
+ _OpSetAttr(self, attr_name, value, plevel)  # thin wrapper: forwards all arguments unchanged and returns None
+
def get(op_name):
"""Get the Op for a given name
"""
return register(op_name, "FPrimalGradient", fgradient, level)
+def register_shape_func(op_name, data_dependant, shape_func=None, level=10):
+ """Register operator shape function for an op.
+
+ Parameters
+ ----------
+ op_name : str
+ The name of the op.
+
+ data_dependant : bool
+ Whether the shape function depends on input data.
+
+ shape_func : function (attrs: Attrs, inputs: List[Tensor], out_ndims: List[IndexExpr])
+ -> shape_tensors: List<Tensor>
+ The function for computing the dynamic output shapes
+
+ level : int
+ The priority level
+ """
+ get(op_name).set_attr("TShapeDataDependant", data_dependant, level)  # the flag is set right away, even when shape_func is still None
+ return register(op_name, "FShapeFunc", shape_func, level)  # the result of register() is returned as-is; presumably usable as a decorator when shape_func is None - confirm against register()
_init_api("relay.op", __name__)
}
Array<Tensor> HybridOpNode::InputTensors() const {
- return inputs;
+ // Because input tensors could be potentially inlined into hybrid scripts,
+ // we need to check if all input tensors are used in the body.
+ std::unordered_set<Tensor> orig_inputs;
+ for (auto t : inputs) {
+ orig_inputs.insert(t);
+ }
+ std::unordered_set<Tensor> visited;
+ Array<Tensor> curr_inputs;
+ ir::PostOrderVisit(body, [&curr_inputs, &orig_inputs, &visited](const NodeRef& n) {
+ const ir::Call *call = n.as<ir::Call>();
+ if (call != nullptr && call->func.defined()) {
+ Tensor t = Operation(call->func.node_).output(call->value_index);
+ if (orig_inputs.count(t) && !visited.count(t)) {
+ curr_inputs.push_back(t);
+ visited.insert(t);
+ }
+ }
+ });
+ return curr_inputs;
}
Operation HybridOpNode::ReplaceInputs(
arith::Analyzer* analyzer,
const std::unordered_map<const Variable*, IntSet> &dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
- for (Tensor t : this->inputs) {
+ auto curr_inputs = InputTensors();
+ for (Tensor t : curr_inputs) {
auto it = out_dom_map->find(t);
if (it == out_dom_map->end()) continue;
TensorDom &dom = it->second;
outputs[i]->dtype);
f_push_bind(buffer, stage->op.output(i));
}
- for (int i = static_cast<int>(inputs.size()) - 1; i >= 0; --i) {
- Buffer buffer = decl_buffer(
- inputs[i]->shape,
- inputs[i]->dtype);
- f_push_bind(buffer, inputs[i]);
+ auto curr_inputs = InputTensors();
+ for (int i = static_cast<int>(curr_inputs.size()) - 1; i >= 0; --i) {
+ Buffer buffer = decl_buffer(curr_inputs[i]->shape, curr_inputs[i]->dtype);
+ f_push_bind(buffer, curr_inputs[i]);
}
std::unordered_map<Tensor, Tensor> rmap;
* tensors have the same names as the operation produces them.
* 2. Once OpNode is wrapped up by an Operation node, it is finalized.
* Later access will be from a const OpNode*.
- * This is a chiken-egg paradox. It is impossible to put the output
+ * This is a chicken-egg paradox. It is impossible to put the output
* tensors into the function body without forming the op node. The
* function body is immutable after the node is formed.
*
using runtime::StorageRank;
using runtime::StorageScope;
-// Find a linear pattern of storage acess
+// Find a linear pattern of storage access
// Used for liveness analysis.
// Composite scopes(loop/thread_launch/IfThen) are represented by two points:
// before_scope -> scope_body -> after_scope
VisitNewScope(op);
}
+ void Visit_(const AssertStmt* op) final {  // AssertStmt is handled by delegating to VisitNewScope
+ VisitNewScope(op);
+ }
+
// linearized access sequence.
std::vector<StmtEntry> linear_seq_;
// The storage scope of each buffer
#include <limits>
#include <mutex>
#include <functional>
+#include <vector>
#include <unordered_map>
+#include "../ir/type_functor.h"
#include "compile_engine.h"
namespace tvm {
return CCacheKey(n);
}
+/*!
+ * \brief Type visitor that records whether any visited tensor type
+ *        has an Any dimension in its shape.
+ */
+struct IsDynamicVisitor : public TypeVisitor {
+  /*! \brief Set to true once a tensor type with an Any dimension is visited. */
+  bool is_dyn{false};
+
+  /*!
+   * \brief Scan the shape of a tensor type for an Any dimension.
+   * \param tt The tensor type being visited.
+   */
+  void VisitType_(const TensorTypeNode* tt) final {
+    // Nothing left to learn once a dynamic dimension has been found.
+    if (is_dyn) return;
+    for (auto dim : tt->shape) {
+      if (dim.as<Any>()) {
+        is_dyn = true;
+        break;
+      }
+    }
+  }
+};
+
+bool IsDynamic(const Type& ty) {  // true when IsDynamicVisitor sets is_dyn while visiting ty
+ IsDynamicVisitor v;
+ v.VisitType(ty);
+ return v.is_dyn;
+}
+
+Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {  // returns a new array; the input is not modified
+ // for now, we always use int32 shape when possible
+ // even if the result of shape inference becomes int64.
+ Array<IndexExpr> res;
+ for (IndexExpr val : shape) {
+ const int64_t* pval = as_const_int(val);
+ if (pval != nullptr) {  // constant entry: checked against the int32 range, then rebuilt as an Int(32) IntImm
+ CHECK_LE(pval[0], std::numeric_limits<int32_t>::max());
+ CHECK_GE(pval[0], std::numeric_limits<int32_t>::min());
+ res.push_back(ir::IntImm::make(Int(32), *pval));
+ } else if (val->is_type<ir::Any>()) {  // Any entry: replaced by the result of Any::ToVar()
+ res.push_back(val.as<ir::Any>()->ToVar());
+ } else {  // any other expression is passed through unchanged
+ res.push_back(val);
+ }
+ }
+ return res;
+}
+
// The getter to get schedule from compile engine.
// Get schedule from functor.
class ScheduleGetter :
explicit ScheduleGetter(Target target)
: target_(target) {}
- Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
- // for now, we always use int32 shape when possible
- // even if the result of shape inference becomes int64.
- Array<IndexExpr> res;
- for (IndexExpr val : shape) {
- const int64_t* pval = as_const_int(val);
- if (pval != nullptr) {
- CHECK_LE(pval[0], std::numeric_limits<int32_t>::max());
- CHECK_GE(pval[0], std::numeric_limits<int32_t>::min());
- res.push_back(ir::IntImm::make(Int(32), *pval));
- } else {
- res.push_back(val);
- }
- }
- return res;
- }
-
std::pair<Schedule, CachedFunc> Create(const Function& prim_func) {
static auto fschedule =
Op::GetAttr<FTVMSchedule>("FTVMSchedule");
const auto* tuple_type = param->type_as<TupleTypeNode>();
for (Type field : tuple_type->fields) {
const auto* ttype = field.as<TensorTypeNode>();
+ // TODO(@icemelon): Allow recursive tuple
CHECK(ttype != nullptr);
tvm::Tensor tensor = tvm::placeholder(
GetShape(ttype->shape), ttype->dtype);
Array<Operation> scalars_;
};
+// Creates shape function from functor: visits the body of a primitive function and collects the tensors returned by each op's registered FShapeFunc.
+class MakeShapeFunc : public ExprFunctor<Array<Tensor>(const Expr&)> {
+ public:
+ MakeShapeFunc() {}
+
+ std::pair<Schedule, CachedFunc> Create(const Function& prim_func) {  // builds the schedule and cached-func record for prim_func's shape function
+ for (auto param : prim_func->params) {
+ param_states_[param] = kNoNeed;  // usage bits are filled in later, while the body is visited
+ Array<tvm::Tensor> data_inputs;
+ Array<tvm::Tensor> shape_inputs;
+
+ auto add_placeholder = [&data_inputs, &shape_inputs](const TensorTypeNode* ttype) {
+ // Add data placeholder
+ Shape shape = GetShape(ttype->shape);
+ tvm::Tensor data_tensor = tvm::placeholder(shape, ttype->dtype);
+ data_inputs.push_back(data_tensor);
+ // Add shape placeholder
+ int64_t ndim = shape.size();
+ Shape sshape;
+ if (ndim > 0) {  // a rank-0 tensor gets a shape placeholder with an empty shape
+ sshape.push_back(tvm::Integer(ndim));
+ }
+ tvm::Tensor shape_tensor = tvm::placeholder(sshape, Int(64));
+ shape_inputs.push_back(shape_tensor);
+ };
+
+ if (const auto *ttype = param->checked_type().as<TensorTypeNode>()) {
+ add_placeholder(ttype);
+ } else {
+ // flatten tuple of tensor type.
+ const auto *tuple_type = param->type_as<TupleTypeNode>();
+ // TODO(@icemelon): Support recursive tuple
+ CHECK(tuple_type);
+ for (Type field : tuple_type->fields) {
+ const auto *ttype = field.as<TensorTypeNode>();
+ CHECK(ttype);
+ add_placeholder(ttype);
+ }
+ }
+ param_data_[param] = data_inputs;
+ param_shapes_[param] = shape_inputs;
+ }
+ readable_name_stream_ << "shape_func";
+ auto cache_node = make_node<CachedFuncNode>();
+ cache_node->outputs = VisitExpr(prim_func->body);  // visiting the body also appends op names to readable_name_stream_
+ auto candidate_name = readable_name_stream_.str();
+ constexpr static size_t kMaxFuncNameLength = 80;
+ if (candidate_name.size() > kMaxFuncNameLength) {  // over-long names are cut to 80 chars and suffixed with a hash of the full name
+ std::stringstream truncated_name;
+ truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
+ truncated_name << "_" << std::hash<std::string>{}(candidate_name) << "_";
+ candidate_name = truncated_name.str();
+ }
+ cache_node->func_name = candidate_name;
+
+ // set inputs
+ for (auto param : prim_func->params) {
+ int state = param_states_[param];
+ cache_node->shape_func_param_states.push_back(IntImm::make(Int(32), state));
+ if (state & kNeedInputData) {  // data placeholders are added before shape placeholders for the same parameter
+ for (auto t : param_data_[param]) {
+ cache_node->inputs.push_back(t);
+ }
+ }
+ if (state & kNeedInputShape) {
+ for (auto t : param_shapes_[param]) {
+ cache_node->inputs.push_back(t);
+ }
+ }
+ }
+
+ CachedFunc cfunc(cache_node);
+ // generate schedule for shape func
+ Array<Operation> out_ops;
+ for (auto t : cache_node->outputs) {
+ out_ops.push_back(t->op);
+ }
+ auto schedule = create_schedule(out_ops);
+ tvm::schedule::AutoInlineInjective(schedule);
+ for (const auto& scalar : scalars_) {  // constants created while visiting are inlined when they are part of the schedule
+ auto scalar_op = scalar->op;
+ if (schedule->Contain(scalar_op)) {
+ schedule[scalar_op].compute_inline();
+ }
+ }
+ return std::make_pair(schedule, cfunc);
+ }
+
+ Array<Tensor> VisitExpr(const Expr& expr) {  // memoizing wrapper around ExprFunctor::VisitExpr
+ auto it = memo_.find(expr);
+ if (it != memo_.end()) {
+ return it->second;
+ } else {
+ Array<Tensor> res = ExprFunctor::VisitExpr(expr);
+ if (expr.as<VarNode>() == nullptr) {
+ // Do not memoize vars because shape functions could use either the data
+ // or the shape of a var each time.
+ memo_[expr] = res;
+ }
+ return res;
+ }
+ }
+
+ Array<Tensor> VisitExpr_(const VarNode* var_node) final {  // returns the data or shape placeholders of a parameter, depending on the innermost enclosing call
+ auto var = GetRef<Var>(var_node);
+ auto it = param_states_.find(var);
+ if (it == param_states_.end()) {
+ LOG(FATAL) << "Free variable " << var->name_hint();
+ return {};
+ } else {
+ CHECK(data_dependants_.size());
+ bool data_dependant = data_dependants_.back();
+ if (data_dependant) {
+ param_states_[var] |= kNeedInputData;
+ return param_data_[var];
+ } else {
+ param_states_[var] |= kNeedInputShape;
+ return param_shapes_[var];
+ }
+ }
+ }
+
+ Array<Tensor> VisitExpr_(const ConstantNode* op) final {  // only scalar constants are accepted
+ CHECK(data_dependants_.size());
+ CHECK(op->is_scalar());
+ bool data_dependant = data_dependants_.back();
+ if (data_dependant) {  // the enclosing op needs the value: emit a 0-d compute holding the constant
+ void* data = op->data->data;
+ DataType dtype = TVMType2Type(op->data->dtype);
+ Tensor value = tvm::compute({}, [&](const Array<tvm::Var>&) {
+ if (dtype == Int(32)) {
+ return make_const(dtype, static_cast<const int32_t*>(data)[0]);
+ } else if (dtype == Int(64)) {
+ return make_const(dtype, static_cast<const int64_t*>(data)[0]);
+ } else if (dtype == Float(32)) {
+ return make_const(dtype, static_cast<const float*>(data)[0]);
+ } else if (dtype == Float(64)) {
+ return make_const(dtype, static_cast<const double*>(data)[0]);
+ } else if (dtype == Bool()) {
+ return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
+ } else {
+ LOG(FATAL) << "not handled";
+ return tvm::Expr();
+ }
+ }, "data_const", topi::kBroadcast);
+ scalars_.push_back(value);
+ return {value};
+ } else {  // the enclosing op needs the shape: emit a 0-d Int(64) compute holding 0
+ Tensor value = tvm::compute({}, [&](const Array<tvm::Var>&) {
+ return make_const(Int(64), 0);
+ }, "shape_const", topi::kBroadcast);
+ scalars_.push_back(value);
+ return {value};
+ }
+ }
+
+ Array<Tensor> VisitExpr_(const CallNode* call_node) final {  // invokes the op's registered FShapeFunc on the visited arguments
+ static auto fshape_func = Op::GetAttr<FShapeFunc>("FShapeFunc");
+ static auto tshape_data_dependant = Op::GetAttr<TShapeDataDependant>(
+ "TShapeDataDependant");
+ CHECK(call_node->op.as<OpNode>())
+ << "Primitive function only allows call into primitive ops";
+ Op op = Downcast<Op>(call_node->op);
+ CHECK(data_dependants_.empty() || !data_dependants_.back())
+ << "Error in op fusion: output of the shape func is fed to a "
+ << "data-dependant shape func";
+ CHECK_GT(fshape_func.count(op), 0)
+ << "Internal error, cannot find ShapeFunc for " << op->name;
+ CHECK_GT(tshape_data_dependant.count(op), 0)
+ << "Internal error, cannot find TShapeDataDependant for " << op->name;
+
+ data_dependants_.push_back(tshape_data_dependant[op]);  // pushed for the duration of the argument visits, popped after the shape func call
+ // Visit all inputs
+ Array<Tensor> inputs;
+ int count_tuple = 0;
+ for (Expr arg : call_node->args) {
+ if (arg->checked_type().as<TupleTypeNode>()) {
+ ++count_tuple;
+ }
+ for (Tensor tensor : VisitExpr(arg)) {
+ inputs.push_back(tensor);
+ }
+ }
+ if (count_tuple) {
+ CHECK_EQ(call_node->args.size(), 1U)
+ << "Only allow function with a single tuple input";
+ }
+ // Get output ndims
+ auto ret_type = call_node->checked_type();
+ Array<IndexExpr> out_ndims;
+ if (const auto* ttype = ret_type.as<TensorTypeNode>()) {
+ out_ndims.push_back(IntImm::make(Int(32), ttype->shape.size()));
+ } else {
+ auto rtype = ret_type.as<TupleTypeNode>();
+ // TODO(@icemelon): Allow recursive tuple
+ CHECK(rtype);
+ for (size_t i = 0; i < rtype->fields.size(); ++i) {
+ auto ttype = rtype->fields[i].as<TensorTypeNode>();
+ CHECK(ttype);
+ out_ndims.push_back(IntImm::make(Int(32), ttype->shape.size()));
+ }
+ }
+ // Call shape function
+ auto outputs = fshape_func[op](call_node->attrs, inputs, out_ndims);
+ data_dependants_.pop_back();
+ readable_name_stream_ << "_" << op->name;
+ return outputs;
+ }
+
+ Array<Tensor> VisitExpr_(const FunctionNode* op) final {  // nested functions are rejected
+ LOG(FATAL) << "Do not support sub function";
+ return Array<Tensor>();
+ }
+
+ Array<Tensor> VisitExpr_(const LetNode* op) final {  // binds the visited value to the let variable through memo_, then visits the body
+ Array<Tensor> val = VisitExpr(op->value);
+ CHECK(!memo_.count(op->var));
+ memo_[op->var] = val;
+ return VisitExpr(op->body);
+ }
+
+ Array<Tensor> VisitExpr_(const TupleNode* op) final {  // each field must be a tensor that visits to exactly one tensor
+ Array<Tensor> fields;
+ for (Expr field : op->fields) {
+ CHECK(field->checked_type().as<TensorTypeNode>())
+ << "Only allow Tuple of Tensor";
+ Array<Tensor> res = VisitExpr(field);
+ CHECK_EQ(res.size(), 1);
+ fields.push_back(res[0]);
+ }
+ return fields;
+ }
+
+ private:
+ /*! \brief String stream for function name */
+ std::ostringstream readable_name_stream_;
+ /*! \brief Map from parameter to its shape function usage state */
+ std::unordered_map<Expr, int, NodeHash, NodeEqual> param_states_;
+ /*! \brief Map from parameter to list of data placeholder */
+ std::unordered_map<Expr, Array<Tensor>, NodeHash, NodeEqual> param_data_;
+ /*! \brief Map from parameter to list of shape placeholder */
+ std::unordered_map<Expr, Array<Tensor>, NodeHash, NodeEqual> param_shapes_;
+ /*! \brief Memoized visit result */
+ std::unordered_map<Expr, Array<Tensor>, NodeHash, NodeEqual> memo_;
+ /*! \brief Stack of data dependencies for shape function */
+ std::vector<bool> data_dependants_;
+ /*! \brief Scalars used in the shape function */
+ Array<Tensor> scalars_;
+};
class CompileEngineImpl : public CompileEngineNode {
public:
}
return value->packed_func;
}
+
+ CachedFunc LowerShapeFunc(const CCacheKey& key) final {  // returns the cached_func of the entry produced by LowerShapeFuncInternal
+ return LowerShapeFuncInternal(key)->cached_func;
+ }
+
void Clear() final {
cache_.clear();
}
value->cached_func = CachedFunc(cache_node);
return value;
}
+ /* implement lowered shape func
+  * Looks the key up in shape_func_cache_ while holding mutex_; on a miss (or an
+  * entry without a cached_func) builds the shape function with MakeShapeFunc,
+  * lowers it and stores the result in the cache entry.
+  */
+ CCacheValue LowerShapeFuncInternal(const CCacheKey& key) {
+ std::lock_guard<std::mutex> lock(mutex_);
+ CCacheValue value;
+ auto it = shape_func_cache_.find(key);
+ if (it != shape_func_cache_.end()) {
+ it->second->use_count += 1;  // every hit is counted, whether or not the entry is already lowered
+ if (it->second->cached_func.defined()) return it->second;
+ value = it->second;
+ } else {
+ value = CCacheValue(make_node<CCacheValueNode>());
+ value->use_count = 0;
+ shape_func_cache_[key] = value;
+ }
+ // Enforce use the target.
+ With<Target> target_scope(key->target);
+
+ CHECK(!value->cached_func.defined());
+ auto spair = MakeShapeFunc().Create(key->source_func);
+ auto cache_node = make_node<CachedFuncNode>(
+ *(spair.second.operator->()));  // copy of the node returned by Create, so name and target can be set on it
+ cache_node->func_name = GetUniqueName(cache_node->func_name);
+ cache_node->target = key->target;
+
+ Array<Tensor> all_args = cache_node->inputs;
+ for (Tensor arg : cache_node->outputs) {  // lowered argument list is the inputs followed by the outputs
+ all_args.push_back(arg);
+ }
+ tvm::BuildConfig bcfg = BuildConfig::Create();
+ std::unordered_map<Tensor, Buffer> binds;  // left empty: no explicit buffer bindings are passed to lower
+ cache_node->funcs = tvm::lower(spair.first, all_args, cache_node->func_name, binds, bcfg);
+ value->cached_func = CachedFunc(cache_node);
+ return value;
+ }
/*!
* \brief Get unique name from name.
* \param name The original name.
std::unordered_map<std::string, int> name_map_;
/*! \brief internal compiler cache */
std::unordered_map<CCacheKey, CCacheValue> cache_;
+ /*! \brief internal compiler cache for shape funcs */
+ std::unordered_map<CCacheKey, CCacheValue> shape_func_cache_;
};
/*! \brief The global compile engine */
namespace tvm {
namespace relay {
+/*!
+ * \brief Indicate whether the data or shape or both of a parameter is used in the shape func.
+ *
+ * The values are combined and tested bitwise by the code that fills and reads
+ * the per-parameter state.
+ */
+enum ShapeFuncParamState {
+ kNoNeed = 0,  // neither the data nor the shape of the parameter is used
+ kNeedInputData = 1,  // bit 0: the parameter's data is used
+ kNeedInputShape = 2,  // bit 1: the parameter's shape is used
+ kNeedBoth = 3,  // both bits set: kNeedInputData | kNeedInputShape
+};
+
/*! \brief Node container to represent a cached function. */
struct CachedFuncNode : public Node {
/* \brief compiled target */
tvm::Array<Tensor> outputs;
/*! \brief The lowered functions to support the function. */
tvm::Array<tvm::LoweredFunc> funcs;
+ /*! \brief Parameter usage states in the shape function. */
+ tvm::Array<Integer> shape_func_param_states;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("target", &target);
v->Visit("inputs", &inputs);
v->Visit("outputs", &outputs);
v->Visit("funcs", &funcs);
+ v->Visit("shape_func_param_states", &shape_func_param_states);
}
static constexpr const char* _type_key = "relay.CachedFunc";
* \return The result.
*/
virtual PackedFunc JIT(const CCacheKey& key) = 0;
+ /*!
+ * \brief Lower the shape function.
+ * \param key The key to the cached function.
+ * \return The result.
+ */
+ virtual CachedFunc LowerShapeFunc(const CCacheKey& key) = 0;
/*! \brief clear the cache. */
virtual void Clear() = 0;
TVM_DECLARE_NODE_TYPE_INFO(CompileEngineNode, Node);
};
-/*! \brier cache entry used in compile engine */
+/*! \brief cache entry used in compile engine */
class CompileEngine : public NodeRef {
public:
CompileEngine() {}
TVM_DLL static const CompileEngine& Global();
};
+/*!
+ * \brief Check if the type is dynamic.
+ * \param ty The type to be checked.
+ * \return The result.
+ */
+bool IsDynamic(const Type& ty);
+
// implementations
inline size_t CCacheKeyNode::Hash() const {
if (hash_ != 0) return hash_;
return TupleValueNode::make(values);
}
- // TODO(@jroesch): this doesn't support mututal letrec
+ // TODO(@jroesch): this doesn't support mutual letrec
inline Value MakeClosure(const Function& func, Var letrec_name = Var()) {
tvm::Map<Var, Value> captured_mod;
Array<Var> free_vars = FreeVars(func);
return MakeClosure(func);
}
- Value InvokePrimitiveOp(Function func,
+  /*!
+   * \brief Run the shape function of \p func on \p args to obtain its output shapes.
+   *
+   * The shape function is lowered for the "llvm" target, and every array passed to
+   * it is placed on CPU device 0.
+   *
+   * \param func The function whose output shapes are computed.
+   * \param args The arguments of the call; each is a tensor or a tuple of tensors.
+   * \return The output shapes, one per output of the shape function.
+   */
+  Array<Shape> ComputeDynamicShape(const Function& func,
+                                   const Array<Value>& args) {
+    auto key = CCacheKeyNode::make(func, Target::Create("llvm"));
+    auto cfunc = engine_->LowerShapeFunc(key);
+    size_t arity = cfunc->inputs.size() + cfunc->outputs.size();
+
+    std::vector<TVMValue> values(arity);
+    std::vector<int> codes(arity);
+    TVMArgsSetter setter(values.data(), codes.data());
+    std::vector<NDArray> inputs(cfunc->inputs.size());
+    std::vector<NDArray> outputs(cfunc->outputs.size());
+
+    DLContext cpu_ctx;
+    cpu_ctx.device_type = kDLCPU;
+    cpu_ctx.device_id = 0;
+
+    // Bind input slot i either to the data of val (copied to the CPU context) or
+    // to an int64 tensor holding the shape of val.
+    auto fset_input = [&](size_t i, const Value& val, bool need_shape) {
+      // Reject a surplus input before it is written past the end of the buffers.
+      CHECK_LT(i, inputs.size()) << "Shape function input sizes mismatch";
+      const TensorValueNode* tv = val.as<TensorValueNode>();
+      CHECK(tv != nullptr) << "expect Tensor argument";
+      if (need_shape) {
+        const auto shape = tv->data.Shape();
+        int64_t ndim = shape.size();
+        NDArray shape_arr;
+        if (ndim == 0) {
+          // A rank-0 tensor gets an empty shape list instead of {0}.
+          shape_arr = NDArray::Empty({}, Type2TVMType(Int(64)), cpu_ctx);
+        } else {
+          shape_arr = NDArray::Empty({ndim}, Type2TVMType(Int(64)), cpu_ctx);
+          int64_t* data = reinterpret_cast<int64_t*>(shape_arr->data);
+          for (int64_t j = 0; j < ndim; ++j) {
+            data[j] = shape[j];
+          }
+        }
+        inputs[i] = shape_arr;
+        setter(i, shape_arr);
+      } else {
+        auto arr = tv->data.CopyTo(cpu_ctx);
+        inputs[i] = arr;
+        setter(i, arr);
+      }
+    };
+
+    // Every argument needs a matching usage state.
+    CHECK_LE(args.size(), cfunc->shape_func_param_states.size())
+        << "Shape function parameter states mismatch";
+    size_t arg_counter = 0;
+    for (size_t i = 0; i < args.size(); ++i) {
+      auto arg = args[i];
+      int state = cfunc->shape_func_param_states[i]->value;
+      if (arg.as<TensorValueNode>()) {
+        if (state & kNeedInputData) {
+          fset_input(arg_counter++, arg, false);
+        }
+        if (state & kNeedInputShape) {
+          fset_input(arg_counter++, arg, true);
+        }
+      } else {
+        // A tuple argument contributes all of its fields' data first, then all
+        // of its fields' shapes.
+        const TupleValueNode* tuple = arg.as<TupleValueNode>();
+        CHECK(tuple != nullptr);
+        if (state & kNeedInputData) {
+          for (size_t j = 0; j < tuple->fields.size(); ++j) {
+            fset_input(arg_counter++, tuple->fields[j], false);
+          }
+        }
+        if (state & kNeedInputShape) {
+          for (size_t j = 0; j < tuple->fields.size(); ++j) {
+            fset_input(arg_counter++, tuple->fields[j], true);
+          }
+        }
+      }
+    }
+    CHECK_EQ(arg_counter, cfunc->inputs.size())
+        << "Shape function input sizes mismatch";
+
+    // Bind output slot i to a 1-d int64 tensor with one entry per dimension of
+    // the given tensor type.
+    auto fset_shape_output = [&](size_t i, Type val_type) {
+      // Reject a surplus output before it is written past the end of the buffers.
+      CHECK_LT(i, outputs.size()) << "Shape function output sizes mismatch";
+      // TODO(@icemelon): allow recursive tuple
+      const TensorTypeNode* rtype = val_type.as<TensorTypeNode>();
+      CHECK(rtype != nullptr);
+      int64_t ndim = rtype->shape.size();
+      auto arr = NDArray::Empty({ndim}, Type2TVMType(Int(64)), cpu_ctx);
+      outputs[i] = arr;
+      setter(arg_counter + i, arr);
+    };
+
+    auto ret_type = func->body->checked_type();
+    size_t out_cnt = 0;
+    if (auto rtype = ret_type.as<TupleTypeNode>()) {
+      out_cnt = rtype->fields.size();
+      for (size_t i = 0; i < out_cnt; ++i) {
+        fset_shape_output(i, rtype->fields[i]);
+      }
+    } else {
+      out_cnt = 1;
+      auto tt = Downcast<TensorType>(ret_type);
+      fset_shape_output(0, tt);
+    }
+    CHECK_EQ(cfunc->outputs.size(), out_cnt)
+        << "Shape function output sizes mismatch";
+
+    PackedFunc shape_func;
+    TVMRetValue rv;
+    if (const auto* f = runtime::Registry::Get("relay.backend.build")) {
+      tvm::runtime::Module m = (*f)(cfunc->funcs, cfunc->target);
+      shape_func = m.GetFunction(cfunc->func_name);
+    } else {
+      LOG(FATAL) << "relay.backend.build is not registered";
+    }
+    shape_func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv);
+
+    // Get output shapes
+    Array<Shape> out_shapes;
+    for (const auto& out_tensor : outputs) {
+      int64_t* shape_data = reinterpret_cast<int64_t*>(out_tensor->data);
+      Shape out_shape;
+      for (int64_t i = 0; i < out_tensor->shape[0]; ++i) {
+        out_shape.push_back(tvm::Integer(shape_data[i]));
+      }
+      out_shapes.push_back(out_shape);
+    }
+    return out_shapes;
+  }
+
+ Value InvokePrimitiveOp(const Function& func,
const Array<Value>& args) {
auto call_node = func->body.as<CallNode>();
return out_tensor;
};
+ Array<Shape> out_shapes;
+ auto ret_type = func->body->checked_type();
+ bool is_dyn = IsDynamic(func->checked_type());
+ if (call_node->op == Op::Get("shape_of")) {
+ // The output shape of shape_of must be static since Relay doesn't support
+ // dynamic rank tensors.
+ is_dyn = false;
+ }
+
+ if (is_dyn) {
+ CHECK(func->IsPrimitive());
+ out_shapes = ComputeDynamicShape(func, args);
+ }
+
PackedFunc packed_func = engine_->JIT(CCacheKeyNode::make(func, target_));
TVMRetValue rv;
if (const TupleTypeNode* rtype = func->body->checked_type().as<TupleTypeNode>()) {
+ CHECK(!is_dyn || out_shapes.size() == rtype->fields.size());
Array<Value> fields;
for (size_t i = 0; i < rtype->fields.size(); ++i) {
- fields.push_back(fset_output(i, rtype->fields[i]));
+ if (is_dyn) {
+ auto sh = out_shapes[i];
+ auto tt = Downcast<TensorType>(rtype->fields[i]);
+ fields.push_back(fset_output(i, TensorTypeNode::make(sh, tt->dtype)));
+ } else {
+ fields.push_back(fset_output(i, rtype->fields[i]));
+ }
}
packed_func.CallPacked(TVMArgs(values.data(), codes.data(), arg_len), &rv);
return TupleValueNode::make(fields);
} else {
- Value out_tensor = fset_output(0, func->body->checked_type());
+ Value out_tensor;
+ if (is_dyn) {
+ CHECK_EQ(out_shapes.size(), 1);
+ auto sh = out_shapes[0];
+ auto tt = Downcast<TensorType>(ret_type);
+ out_tensor = fset_output(0, TensorTypeNode::make(sh, tt->dtype));
+ } else {
+ out_tensor = fset_output(0, ret_type);
+ }
packed_func.CallPacked(TVMArgs(values.data(), codes.data(), arg_len), &rv);
return out_tensor;
}
* \brief A compiler from relay::Module to the VM byte code.
*/
+#include <tvm/operation.h>
#include <tvm/relay/error.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/interpreter.h>
#include <tvm/logging.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/vm.h>
+#include <topi/tags.h>
+#include <algorithm>
#include <iostream>
#include <memory>
#include <set>
// (@jroesch): VM passes, eventually declare as passes.
bool IsClosure(const Function& func);
-// Compute the constant pool, i.e a mapping from Constant node to constant index.
-struct ConstantPool : ExprVisitor {
- std::set<GlobalVar> visited;
- Module module;
- ConstMap const_map;
- ConstTensorShapeMap const_tensor_shape_map;
-
- size_t index;
-
- explicit ConstantPool(const Module& mod) : module(mod), const_map(), index(0) {}
-
- void VisitExpr_(const GlobalVarNode* var_node) {
- auto gvar = GetRef<GlobalVar>(var_node);
- if (visited.find(gvar) == visited.end()) {
- visited.insert(gvar);
- this->VisitExpr(this->module->Lookup(gvar));
- }
- }
-
- void VisitExpr_(const ConstantNode* const_node) {
- auto konst = GetRef<Constant>(const_node);
- auto it = this->const_map.find(konst);
- if (it == this->const_map.end()) {
- this->const_map.insert({konst, index++});
- }
- }
-};
-
-std::tuple<ConstMap, ConstTensorShapeMap> LayoutConstantPool(const Module& module) {
- auto cp = ConstantPool(module);
- for (auto& func : module->functions) {
- cp.VisitExpr(func.first);
- }
- return std::make_tuple(cp.const_map, cp.const_tensor_shape_map);
-}
-
void InstructionPrint(std::ostream& os, const Instruction& instr);
// Represent a runtime object that's going to be matched by pattern match expressions
class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
public:
- VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets)
+ VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets, Target target_host)
: last_register_(0),
registers_num_(0),
engine_(CompileEngine::Global()),
context_(context),
- targets_(targets) {}
+ targets_(targets),
+ target_host_(target_host) {}
VMFunction Compile(const GlobalVar& var, const Function& func) {
size_t i = 0;
}
void VisitExpr_(const ConstantNode* const_node) {  // Append the constant to the context's pool and emit a LoadConst for it.
- auto rconst = GetRef<Constant>(const_node);
- auto it = this->context_->const_map.find(rconst);
- CHECK(it != this->context_->const_map.end());
- Emit(Instruction::LoadConst(it->second, NewRegister()));
+ size_t konst_idx = context_->constants.size();  // index the constant will have once it is appended
+ context_->constants.push_back(const_node->data);
+ Emit(Instruction::LoadConst(konst_idx, NewRegister()));  // load it into a fresh register
}
void VisitExpr_(const VarNode* var_node) {
}
void VisitExpr_(const LetNode* let_node) {
- DLOG(INFO) << let_node->value;
+ DLOG(INFO) << AsText(let_node->value);
this->VisitExpr(let_node->value);
var_register_map_.insert({let_node->var, this->last_register_});
this->VisitExpr(let_node->body);
this->last_register_ = true_register;
}
- Instruction AllocTensorFromType(const TensorTypeNode* ttype) {
- TVMType dltype = Type2TVMType(ttype->dtype);
- auto tensor_type = GetRef<TensorType>(ttype);
+ Index EmitGetShape(const TensorTypeNode* ttype, Index reg) {  // Emit code producing the shape of the tensor in `reg`; returns the register holding that shape.
+ bool const_shape = true;  // cleared as soon as one dimension is not an IntImm
std::vector<int64_t> shape;
- for (auto dim : tensor_type->shape) {
- shape.push_back(Downcast<tvm::Integer>(dim)->value);
+ for (auto dim : ttype->shape) {
+ if (auto kdim = dim.as<IntImm>()) {
+ shape.push_back(kdim->value);
+ } else {
+ const_shape = false;
+ }
+ }
+ if (const_shape) {  // static shape: store it as an int64 tensor in the constant pool
+ int64_t ndim = shape.size();
+ DLContext cpu_ctx;
+ cpu_ctx.device_type = kDLCPU;
+ cpu_ctx.device_id = 0;
+ NDArray shape_tensor;
+ if (ndim == 0) {  // rank-0 type: use an empty shape list rather than {0}
+ shape_tensor = NDArray::Empty({}, Type2TVMType(Int(64)), cpu_ctx);
+ } else {
+ shape_tensor = NDArray::Empty({ndim}, Type2TVMType(Int(64)), cpu_ctx);
+ int64_t* dims = reinterpret_cast<int64_t*>(shape_tensor->data);
+ for (size_t i = 0; i < shape.size(); ++i) {
+ dims[i] = shape[i];
+ }
+ }
+ size_t konst_idx = context_->constants.size();
+ context_->constants.push_back(shape_tensor);
+ Emit(Instruction::LoadConst(konst_idx, NewRegister()));
+ return last_register_;
+ }
+ // For a dynamic shape, we need to insert a shape_of op to get the shape at runtime
+ auto attrs = make_node<ShapeOfAttrs>();
+ attrs->dtype = Int(64);
+ static const Op& op = Op::Get("shape_of");
+ auto input = VarNode::make("input", GetRef<Type>(ttype));
+ auto expr = CallNode::make(op, {input}, Attrs(attrs), {});
+ auto func = FunctionNode::make({input}, expr, IncompleteTypeNode::make(Kind::kType), {});  // single-parameter wrapper around shape_of; return type left incomplete
+ auto mod = ModuleNode::make({}, {});
+ auto main_gv = GlobalVarNode::make("main");
+ mod->Add(main_gv, func);
+ func = mod->Lookup(main_gv);  // NOTE(review): presumably fetches the type-checked version of func -- confirm
+
+ // shape_of op has to be run on the host target
+ // TODO(@icemelon9): handle heterogeneous target, such as cuda
+ auto key = CCacheKeyNode::make(func, target_host_);
+ auto cfunc = engine_->Lower(key);
+ auto op_index = -1;
+ if (context_->seen_funcs.find(cfunc->funcs[0]) == context_->seen_funcs.end()) {  // first use: register the cached function and remember its index
+ op_index = context_->cached_funcs.size();
+ context_->cached_funcs.push_back(cfunc);
+ context_->seen_funcs[cfunc->funcs[0]] = op_index;
+ } else {
+ op_index = context_->seen_funcs[cfunc->funcs[0]];
+ }
+ std::vector<Index> arg_regs{reg};
+ int64_t ndim = ttype->shape.size();
+ if (ndim == 0) {
+ Emit(Instruction::AllocTensor({}, Int(64), NewRegister()));
+ } else {
+ Emit(Instruction::AllocTensor({ndim}, Int(64), NewRegister()));
+ }
+ Index shape_reg = last_register_;  // register of the tensor that receives the shape
+ arg_regs.push_back(shape_reg);
+ Emit(Instruction::InvokePacked(op_index, 2, 1, arg_regs));  // arity 2 (input tensor + shape output), 1 output
+ return shape_reg;
+ }
+
+  /*!
+   * \brief Emit the instructions that run the shape function of \p func and then
+   *  allocate the output tensors from the computed shapes.
+   * \param ret_type The return type of the call: a tensor type or a tuple of tensor types.
+   * \param func The function whose shape function is invoked.
+   * \param unpacked_arg_regs The registers of the arguments, with tuples flattened.
+   * \return The registers of the allocated output tensors.
+   */
+  std::vector<Index> EmitShapeFunc(const Type& ret_type, const Function& func,
+                                   const std::vector<Index>& unpacked_arg_regs) {
+    // Find the mapping from params to registers
+    size_t idx = 0;
+    std::vector<std::vector<Index>> param_regs;
+    std::vector<std::vector<const TensorTypeNode*>> param_types;
+    for (auto param : func->params) {
+      auto ty = param->checked_type();
+      std::vector<Index> regs;
+      std::vector<const TensorTypeNode*> types;
+      if (auto ttype = ty.as<TensorTypeNode>()) {
+        regs.push_back(unpacked_arg_regs[idx++]);
+        types.push_back(ttype);
+      } else if (const auto* tuple_ty = ty.as<TupleTypeNode>()) {
+        // The fields of a tuple parameter come from the parameter's own type,
+        // not from the return type of the call.
+        for (size_t j = 0; j < tuple_ty->fields.size(); ++j, ++idx) {
+          regs.push_back(unpacked_arg_regs[idx]);
+          auto ttype = tuple_ty->fields[j].as<TensorTypeNode>();
+          CHECK(ttype);
+          types.push_back(ttype);
+        }
+      } else {
+        LOG(FATAL) << "unsupported parameter type " << ty;
+      }
+      param_regs.push_back(regs);
+      param_types.push_back(types);
+    }
+
+    // Lower shape function
+    auto key = CCacheKeyNode::make(func, target_host_);
+    auto cfunc = engine_->LowerShapeFunc(key);
+    int op_index = -1;
+    if (context_->seen_funcs.count(cfunc->funcs[0]) == 0) {
+      op_index = context_->cached_funcs.size();
+      context_->cached_funcs.push_back(cfunc);
+      context_->seen_funcs[cfunc->funcs[0]] = op_index;
+    } else {
+      op_index = context_->seen_funcs[cfunc->funcs[0]];
+    }
+
+    // Prepare input and output registers
+    CHECK_LE(func->params.size(), cfunc->shape_func_param_states.size())
+        << "Shape function parameter states mismatch";
+    std::vector<Index> shape_func_args;
+    std::vector<Index> shape_regs;
+    for (size_t i = 0; i < func->params.size(); ++i) {
+      int state = cfunc->shape_func_param_states[i]->value;
+      if (state & kNeedInputData) {
+        for (auto reg : param_regs[i]) {
+          // TODO(@icemelon9): Need to copy data here for heterogeneous exec
+          shape_func_args.push_back(reg);
+        }
+      }
+      if (state & kNeedInputShape) {
+        for (size_t j = 0; j < param_regs[i].size(); ++j) {
+          shape_func_args.push_back(EmitGetShape(param_types[i][j], param_regs[i][j]));
+        }
+      }
+    }
+    for (auto t : cfunc->outputs) {
+      // Each output of the shape function is a 1-d tensor whose length is read
+      // from its first dimension, which must be constant.
+      const auto* out_len = t->shape[0].as<IntImm>();
+      CHECK(out_len) << "Shape function output must have a constant length";
+      int64_t ndim = out_len->value;
+      Emit(Instruction::AllocTensor({ndim}, t->dtype, NewRegister()));
+      shape_func_args.push_back(last_register_);
+      shape_regs.push_back(last_register_);
+    }
+
+    int arity = shape_func_args.size();
+    int ret_count = shape_regs.size();
+    Emit(Instruction::InvokePacked(op_index, arity, ret_count, shape_func_args));
+
+    // Alloc return tensors given the shape regs
+    std::vector<DataType> ret_dtypes;
+    if (const auto* tuple_type = ret_type.as<TupleTypeNode>()) {
+      for (auto field : tuple_type->fields) {
+        const TensorTypeNode* tty = field.as<TensorTypeNode>();
+        CHECK(tty);
+        ret_dtypes.push_back(tty->dtype);
+      }
+    } else {
+      auto tty = ret_type.as<TensorTypeNode>();
+      CHECK(tty);
+      ret_dtypes.push_back(tty->dtype);
+    }
+    // Every computed shape needs a matching output dtype.
+    CHECK_LE(shape_regs.size(), ret_dtypes.size())
+        << "Shape function output sizes mismatch";
+    std::vector<Index> ret_regs;
+    for (size_t i = 0; i < shape_regs.size(); ++i) {
+      Emit(Instruction::AllocTensorReg(shape_regs[i], ret_dtypes[i], NewRegister()));
+      ret_regs.push_back(last_register_);
+    }
+    return ret_regs;
+  }
+
+  /*!
+   * \brief Emit the allocation of the output tensor(s) of a call to \p func.
+   *
+   * Dynamic functions (other than shape_of) get their outputs through
+   * EmitShapeFunc; otherwise the outputs are allocated from the constant
+   * dimensions of \p ret_type.
+   *
+   * \param ret_type The return type: a tensor type or a tuple of tensor types.
+   * \param func The function being invoked; its body must be a call.
+   * \param unpacked_arg_regs The registers of the arguments, with tuples flattened.
+   * \return The registers of the allocated output tensors.
+   */
+  std::vector<Index> AllocReturnType(const Type& ret_type, const Function& func,
+                                     const std::vector<Index>& unpacked_arg_regs) {
+    const auto* call = func->body.as<CallNode>();
+    CHECK(call) << "AllocReturnType expects the function body to be a call";
+    auto op = call->op;
+    // 1. If either func param types or ret type is dynamic, we need to insert
+    // shape func to perform type checking at runtime.
+    // 2. We skip the shape_of function since currently Relay doesn't support
+    // dynamic rank tensor.
+    if (op != Op::Get("shape_of") && IsDynamic(func->checked_type())) {
+      return EmitShapeFunc(ret_type, func, unpacked_arg_regs);
     }
-    return Instruction::AllocTensor(shape, dltype, NewRegister());
+    std::vector<Index> ret_regs;
+    // Emit an AllocTensor for one statically shaped tensor type and record its register.
+    auto alloc_tensor = [&](const TensorTypeNode* ttype) {
+      std::vector<int64_t> shape;
+      for (auto dim : ttype->shape) {
+        shape.push_back(Downcast<tvm::Integer>(dim)->value);
+      }
+      Emit(Instruction::AllocTensor(shape, Type2TVMType(ttype->dtype), NewRegister()));
+      ret_regs.push_back(last_register_);
+    };
+    if (const TensorTypeNode* ttype = ret_type.as<TensorTypeNode>()) {
+      alloc_tensor(ttype);
+    } else if (const TupleTypeNode* ttype = ret_type.as<TupleTypeNode>()) {
+      for (auto field : ttype->fields) {
+        // Fail with a message instead of dereferencing a null tensor type.
+        const TensorTypeNode* field_type = field.as<TensorTypeNode>();
+        CHECK(field_type) << "only supports non-nested tuples currently, found " << field;
+        alloc_tensor(field_type);
+      }
+    } else {
+      LOG(FATAL) << "Unsupported return value type";
+    }
+    return ret_regs;
   }
void EmitInvokePrimitive(const Function& func,
- const std::vector<Index>& args_registers,
+ const std::vector<Index>& arg_registers,
const Type& ret_type) {
std::vector<Index> unpacked_arg_regs;
std::vector<Instruction> allocs;
// Arity calculation must flatten tuples.
size_t arity = 0;
- CHECK_EQ(func->params.size(), args_registers.size());
+ CHECK_EQ(func->params.size(), arg_registers.size());
for (size_t i = 0; i < func->params.size(); i++) {
auto ty = func->params[i]->checked_type();
if (ty.as<TensorTypeNode>()) {
- unpacked_arg_regs.push_back(args_registers[i]);
+ unpacked_arg_regs.push_back(arg_registers[i]);
arity += 1;
} else if (auto tuple_ty = ty.as<TupleTypeNode>()) {
for (size_t f = 0; f < tuple_ty->fields.size(); f++) {
<< "only supports non-nested tuples currently "
<< "found " << field;
auto dst = NewRegister();
- Emit(Instruction::GetField(args_registers[i], f, dst));
+ Emit(Instruction::GetField(arg_registers[i], f, dst));
unpacked_arg_regs.push_back(dst);
}
arity += tuple_ty->fields.size();
}
}
- size_t return_val_count = 0;
- if (const TensorTypeNode* ttype = ret_type.as<TensorTypeNode>()) {
- // Allocate space for the return tensor.
- auto alloc = AllocTensorFromType(ttype);
- allocs.push_back(alloc);
- return_val_count = 1;
- } else if (const TupleTypeNode* ttype = ret_type.as<TupleTypeNode>()) {
- std::vector<Index> fields_registers;
-
- for (size_t i = 0; i < ttype->fields.size(); ++i) {
- auto f = ttype->fields[i];
- auto f_type = f.as<TensorTypeNode>();
- allocs.push_back(AllocTensorFromType(f_type));
- fields_registers.push_back(allocs.back().dst);
- }
- return_val_count = ttype->fields.size();
- } else {
- LOG(FATAL) << "Unsupported return value type";
- }
-
- arity += return_val_count;
- for (auto& alloc : allocs) {
- Emit(alloc);
- unpacked_arg_regs.push_back(alloc.dst);
+ auto ret_regs = AllocReturnType(ret_type, func, unpacked_arg_regs);
+ size_t return_count = ret_regs.size();
+ arity += return_count;
+ for (auto reg : ret_regs) {
+ unpacked_arg_regs.push_back(reg);
}
// Next generate the invoke instruction.
CHECK_EQ(cfunc->funcs.size(), 1);
auto op_index = -1;
if (context_->seen_funcs.find(cfunc->funcs[0]) == context_->seen_funcs.end()) {
- op_index = context_->lowered_funcs.size();
- context_->lowered_funcs.push_back(cfunc->funcs[0]);
+ op_index = context_->cached_funcs.size();
+ context_->cached_funcs.push_back(cfunc);
context_->seen_funcs[cfunc->funcs[0]] = op_index;
} else {
op_index = context_->seen_funcs[cfunc->funcs[0]];
}
- Emit(Instruction::InvokePacked(op_index, arity, return_val_count, unpacked_arg_regs));
+ Emit(Instruction::InvokePacked(op_index, arity, return_count, unpacked_arg_regs));
- if (return_val_count > 1) {
+ if (return_count > 1) {
// return value is a tuple, we need to create a tuple
std::vector<Index> fields_registers;
- for (size_t i = arity - return_val_count; i < arity; ++i) {
+ for (size_t i = arity - return_count; i < arity; ++i) {
fields_registers.push_back(unpacked_arg_regs[i]);
}
- Emit(Instruction::AllocDatatype(0, return_val_count, fields_registers, NewRegister()));
+ Emit(Instruction::AllocDatatype(0, return_count, fields_registers, NewRegister()));
}
}
VMCompilerContext* context_;
/*! \brief Target devices. */
TargetsMap targets_;
+ /*! \brief Host target. */
+ Target target_host_;
};
// in the VMFunction table.
PopulateGlobalMap();
- // Next we populate constant map.
- auto constant_analysis_result = LayoutConstantPool(context_.module);
- context_.const_map = std::get<0>(constant_analysis_result);
- context_.const_tensor_shape_map = std::get<1>(constant_analysis_result);
-
// Next we get ready by allocating space for
// the global state.
vm_->functions.resize(context_.module->functions.size());
- vm_->constants.resize(context_.const_map.size() + context_.const_tensor_shape_map.size());
- for (auto pair : context_.const_map) {
- vm_->constants[pair.second] = Object::Tensor(pair.first->data);
- }
-
- for (auto pair : context_.const_tensor_shape_map) {
- vm_->constants[pair.second.first] = Object::Tensor(pair.second.second);
- }
+ // Next we get ready by allocating space for
+ // the global state.
+ vm_->functions.resize(context_.module->functions.size());
for (auto named_func : context_.module->functions) {
auto gvar = named_func.first;
auto func = named_func.second;
- VMFunctionCompiler func_compiler(&context_, targets_);
+ VMFunctionCompiler func_compiler(&context_, targets_, target_host_);
auto vm_func = func_compiler.Compile(gvar, func);
size_t func_index = context_.global_map.at(gvar);
}
#endif // USE_RELAY_DEBUG
+ // populate constants
+ for (auto data : context_.constants) {
+ vm_->constants.push_back(Object::Tensor(data));
+ }
+
LibraryCodegen();
for (auto gv : context_.global_map) {
Module VMCompiler::OptimizeModule(const Module& mod) {  // Run the fixed pass sequence below on mod and return the result.
// TODO(@icemelon9): check number of targets and build config, add more optimization passes
transform::Sequential seq({transform::SimplifyInference(),
- transform::ToANormalForm(),
transform::InlinePrimitives(),
+ // TODO(@wweic): The FuseOps pass currently doesn't handle Let.
+ // For now, we put FuseOps before ToANormalForm to enable it.
+ transform::FuseOps(),
+ transform::ToANormalForm(),
transform::LambdaLift(),
- transform::InlinePrimitives(),
- transform::FuseOps()});
+ transform::InlinePrimitives()});
auto pass_ctx = transform::PassContext::Create();
tvm::With<relay::transform::PassContext> ctx(pass_ctx);
return seq(mod);
}
void VMCompiler::LibraryCodegen() {  // Build all cached functions into vm_->lib and fill vm_->primitive_map.
- auto const& lowered_funcs = context_.lowered_funcs;
- if (lowered_funcs.size() == 0) {
+ auto const &cached_funcs = context_.cached_funcs;
+ if (cached_funcs.size() == 0) {  // nothing to build
return;
}
- // TODO(@icemelon9): support heterogeneous targets
- Target target;
- for (auto kv : targets_) {
- target = kv.second;
+ std::unordered_map<std::string, Array<LoweredFunc>> tgt_funcs;  // lowered functions grouped by target string
+ for (auto &cfunc : cached_funcs) {
+ std::string target_str = cfunc->target->str();
+ if (tgt_funcs.count(target_str) == 0) {
+ tgt_funcs.emplace(target_str, Array<LoweredFunc>{cfunc->funcs[0]});  // only the first lowered function of each cached function is used
+ } else {
+ tgt_funcs[target_str].push_back(cfunc->funcs[0]);
+ }
}
- if (const auto* f = runtime::Registry::Get("relay.backend.build")) {
- runtime::Module mod =
- (*f)(tvm::Array<LoweredFunc>(lowered_funcs.begin(), lowered_funcs.end()), target,
- target_host_);
+ Map<Target, Array<LoweredFunc>> funcs;
+ for (auto &it : tgt_funcs) {
+ funcs.Set(Target::Create(it.first), it.second);  // re-create a Target from each target string
+ }
+
+ if (const auto *f = runtime::Registry::Get("relay.backend.build")) {
+ // The target is just a dummy arg because funcs already contains corresponding target
+ // therefore target won't be used in the build function
+ runtime::Module mod = (*f)(funcs, Target(), target_host_);
CHECK(mod.operator->());
vm_->lib = mod;
} else {
LOG(FATAL) << "relay.backend.build is not registered";
}
size_t primitive_index = 0;
- for (auto lfunc : lowered_funcs) {
- vm_->primitive_map.insert({lfunc->name, primitive_index++});
+ for (auto cfunc : cached_funcs) {
+ vm_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++});  // map each function name to its position in cached_funcs
}
}
TagMap tag_map;
// Map from global var to a unique integer
GlobalMap global_map;
- // Map from Const object to its index in const pool
- ConstMap const_map;
- // Map from Const tensor shape to its index in const pool
- ConstTensorShapeMap const_tensor_shape_map;
- // List of lowered functions
- std::vector<LoweredFunc> lowered_funcs;
+ // List of constants
+ std::vector<NDArray> constants;
+ // List of cached functions
+ std::vector<CachedFunc> cached_funcs;
// The functions that have been lowered.
std::unordered_map<LoweredFunc, size_t, NodeHash, NodeEqual> seen_funcs;
};
TVM_REGISTER_API("relay.op._GetOp").set_body_typed<Op(std::string)>(Op::Get);
TVM_REGISTER_API("relay.op._OpGetAttr")
- .set_body([](TVMArgs args, TVMRetValue* rv) {
- Op op = args[0];
- std::string attr_name = args[1];
- auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
- if (op_map.count(op)) {
- *rv = op_map[op];
- }
- });
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+ Op op = args[0];
+ std::string attr_name = args[1];
+ auto op_map = Op::GetAttr<TVMRetValue>(attr_name);
+ if (op_map.count(op)) {
+ *rv = op_map[op];
+ }
+ });
+
+TVM_REGISTER_API("relay.op._OpSetAttr")  // args: op, attribute name, attribute value, plevel
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+ Op op = args[0];
+ std::string attr_name = args[1];
+ runtime::TVMArgValue value = args[2];
+ int plevel = args[3];  // NOTE(review): presumably the priority level of the attribute -- confirm
+ auto& reg =
+ OpRegistry::Registry()->__REGISTER_OR_GET__(op->name).set_name();  // registry entry looked up by the op's name
+ reg.set_attr(attr_name, value, plevel);  // rv is left unset
+ });
TVM_REGISTER_API("relay.op._Register")
.set_body([](TVMArgs args, TVMRetValue* rv) {
used_input_dims.insert(src_idx);
IndexExpr d2 = data_shape[src_idx++];
used_output_dims.insert(oshape.size());
- oshape.push_back(d1 * d2);
+ if (d1.as<Any>() || d2.as<Any>()) {
+ oshape.push_back(Any::make());
+ } else {
+ oshape.push_back(d1 * d2);
+ }
} else if (svalue == -4) {
// split the source dim s into two dims
// read the left dim and then the right dim (either can be -1)
oshape.push_back(d2);
}
}
+ } else {
+ CHECK(false) << "Unsupported special value: " << svalue;
}
}
const Target& target) {
const auto* out_ttype = out_type.as<TensorTypeNode>();
CHECK(out_ttype != nullptr);
- return { topi::reshape(inputs[0], out_ttype->shape) };
+ Array<IndexExpr> newshape;
+ for (auto val : out_ttype->shape) {
+ if (val->is_type<ir::Any>()) {
+ newshape.push_back(val.as<ir::Any>()->ToVar());
+ } else {
+ newshape.push_back(val);
+ }
+ }
+ return { topi::reshape(inputs[0], newshape) };
}
Expr MakeReshape(Expr data,
.set_support_level(3)
.add_type_rel("Arange", ArangeRel)
.set_attr<FTVMCompute>("FTVMCompute", ArangeCompute)
-.set_attr<TOpPattern>("TOpPattern", kInjective)
+// TODO(@icemelon): Change arange to kOpaque because FuseOps doesn't consider dynamic shape
+.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<AnyCodegenStrategy>("AnyCodegenStrategy", kVariableDimensions);
// repeat operator
.add_argument("data", "Tensor", "The input tensor.")
.add_type_rel("ShapeOf", ShapeOfRel)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
-.set_attr<TOpPattern>("TOpPattern", kInjective)
+// Use kOpaque for shape_of op for now since it won't be performance critical,
+// and it makes things easier for dynamic shape func
+.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
ElemwiseArbitraryLayout)
.set_support_level(10)
for (; i <= std::min(ndim1, ndim2); ++i) {
IndexExpr s1 = t1->shape[ndim1 - i];
IndexExpr s2 = t2->shape[ndim2 - i];
- if (EqualCheck(s1, s2)) {
- oshape.push_back(s1);
- } else if (EqualConstInt(s1, 1)) {
+ if (EqualConstInt(s1, 1)) {
oshape.push_back(s2);
} else if (EqualConstInt(s2, 1)) {
oshape.push_back(s1);
- } else if (s1.as<Any>() && EqualConstInt(s2, 1)) {
- // TODO(@jroesch): we need to come back to this
+ } else if (s1.as<Any>()) {
+ // s1 == 1 || s1 == s2
oshape.push_back(s2);
- } else if (s2.as<Any>() && EqualConstInt(s1, 1)) {
+ } else if (s2.as<Any>()) {
+ // s2 == 1 || s2 == s1
+ oshape.push_back(s1);
+ } else if (EqualCheck(s1, s2)) {
oshape.push_back(s1);
} else {
RELAY_ERROR(
if (it == gmap_.end()) return "";
std::ostringstream os;
auto *group = it->second->FindRoot();
- os << "group=" << group;
+ os << " /* group=" << group << " */";
return os.str();
});
LOG(INFO) << "Dump of group info:\n" << text;
/*!
* \brief Get the pass information/meta data.
*/
- PassInfo Info() const { return pass_info; }
+ PassInfo Info() const override { return pass_info; }
TVM_DLL static ModulePass make(
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func,
/*!
* \brief Get the pass information/meta data.
*/
- PassInfo Info() const { return pass_info; }
+ PassInfo Info() const override { return pass_info; }
TVM_DLL static FunctionPass make(
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func,
/*!
* \brief Get the pass information/meta data.
*/
- PassInfo Info() const { return pass_info; }
+ PassInfo Info() const override { return pass_info; }
/*!
* \brief Check if a pass is enabled.
void InstructionPrint(std::ostream& os, const Instruction& instr) {
switch (instr.op) {
case Opcode::Move: {
- os << "move $" << instr.dst << " $" << instr.from << std::endl;
+ os << "move $" << instr.dst << " $" << instr.from;
break;
}
case Opcode::Ret: {
- os << "ret $" << instr.result << std::endl;
+ os << "ret $" << instr.result;
break;
}
case Opcode::Fatal: {
<< ", out: $"
<< StrJoin<RegName>(instr.packed_args, instr.arity - instr.output_size,
instr.output_size, ", $")
- << ")" << std::endl;
+ << ")";
break;
}
case Opcode::AllocTensor: {
instr.alloc_tensor.ndim)
<< "] ";
DLDatatypePrint(os, instr.alloc_tensor.dtype);
- os << std::endl;
break;
}
case Opcode::AllocTensorReg: {
os << "alloc_tensor_reg $" << instr.dst << " $"
<< instr.alloc_tensor_reg.shape_register << " ";
DLDatatypePrint(os, instr.alloc_tensor_reg.dtype);
- os << std::endl;
break;
}
case Opcode::AllocDatatype: {
os << "alloc_data $" << instr.dst << " tag(" << instr.constructor_tag << ") [$"
- << StrJoin<RegName>(instr.datatype_fields, 0, instr.num_fields, ",$") << "]"
- << std::endl;
+ << StrJoin<RegName>(instr.datatype_fields, 0, instr.num_fields, ",$") << "]";
break;
}
case Opcode::AllocClosure: {
os << "alloc_closure $" << instr.dst << " VMFunc[" << instr.clo_index
<< "]($" << StrJoin<RegName>(instr.free_vars, 0, instr.num_freevar, ",$")
- << ")"
- << std::endl;
+ << ")";
break;
}
case Opcode::If: {
os << "if " << "$" << instr.if_op.test << " " << instr.if_op.target << " "
- << instr.if_op.true_offset << " " << instr.if_op.false_offset
- << std::endl;
+ << instr.if_op.true_offset << " " << instr.if_op.false_offset;
break;
}
case Opcode::Invoke: {
os << "invoke $" << instr.dst << " VMFunc[" << instr.func_index << "]($"
<< StrJoin<RegName>(instr.invoke_args_registers, 0, instr.num_args, ",$")
- << ")"
- << std::endl;
+ << ")";
break;
}
case Opcode::InvokeClosure: {
os << "invoke_closure $" << instr.dst << " $" << instr.closure << "($"
<< StrJoin<RegName>(instr.closure_args, 0, instr.num_closure_args, ",$")
- << ")"
- << std::endl;
+ << ")";
break;
}
case Opcode::LoadConst: {
- os << "load_const $" << instr.dst << " Const[" << instr.const_index << "]"
- << std::endl;
+ os << "load_const $" << instr.dst << " Const[" << instr.const_index << "]";
break;
}
case Opcode::LoadConsti: {
- os << "load_consti $" << instr.dst << " Const[" << instr.load_consti.val << "]"
- << std::endl;
+ os << "load_consti $" << instr.dst << " Const[" << instr.load_consti.val << "]";
break;
}
case Opcode::GetField: {
os << "get_field $" << instr.dst << " $" << instr.object << "["
- << instr.field_index << "]"
- << std::endl;
+ << instr.field_index << "]";
break;
}
case Opcode::GetTag: {
- os << "get_tag $" << instr.dst << " $" << instr.get_tag.object << std::endl;
+ os << "get_tag $" << instr.dst << " $" << instr.get_tag.object;
break;
}
case Opcode::Goto: {
- os << "goto " << instr.pc_offset << std::endl;
+ os << "goto " << instr.pc_offset;
break;
}
default:
void VMFunctionPrint(std::ostream& os, const VMFunction& vm_func) {  // Print the function name, then one "index: instruction;" line per instruction.
os << vm_func.name << ": " << std::endl;
for (size_t i = 0; i < vm_func.instructions.size(); ++i) {
- os << i << ": ";
- InstructionPrint(os, vm_func.instructions[i]);
- os << ";" << std::endl;
+ os << i << ": " << vm_func.instructions[i] << ";" << std::endl;  // uses the stream operator of Instruction
}
}
while (true) {
main_loop:
auto const& instr = this->code[this->pc];
- DLOG(INFO) << "Executing(" << pc << "): ";
+ DLOG(INFO) << "Executing(" << pc << "): " << instr;
#if USE_RELAY_DEBUG
InstructionPrint(std::cout, instr);
#endif // USE_RELAY_DEBUG
std::vector<Array<Expr> > new_body(sch->stages.size());
std::vector<bool> changed(sch->stages.size(), false);
+ std::vector<Stmt> new_hybrid_body(sch->stages.size());
+ std::vector<bool> hybrid_changed(sch->stages.size(), false);
// inline all the ops
for (size_t i = sch->stages.size(); i != 0; --i) {
Stage stage = sch->stages[i - 1];
for (size_t j = i; j < sch->stages.size(); ++j) {
Stage s = sch->stages[j];
const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
+ const HybridOpNode* hybrid = s->op.as<HybridOpNode>();
if (compute) {
if (!new_body[j].size()) {
new_body[j] = compute->body;
}
}
}
+ } else if (hybrid) {
+ if (!new_hybrid_body[j].defined()) {
+ new_hybrid_body[j] = hybrid->body;
+ }
+ Stmt new_stmt = ir::Inline(new_hybrid_body[j], stage->op, args, body);
+ if (!new_stmt.same_as(new_hybrid_body[j])) {
+ new_hybrid_body[j] = new_stmt;
+ hybrid_changed[j] = true;
+ }
}
}
}
}
s->op = op;
}
+ } else if (hybrid_changed[i]) {
+ const HybridOpNode* hybrid = sch->stages[i]->op.as<HybridOpNode>();
+ CHECK(hybrid);
+ Operation op = HybridOpNode::make(
+ hybrid->name, hybrid->tag, hybrid->attrs, hybrid->inputs,
+ hybrid->outputs, new_hybrid_body[i]);
+ op = op->ReplaceInputs(op, repl);
+ for (int idx = 0; idx < s->op->num_outputs(); ++idx) {
+ repl[s->op.output(idx)] = op.output(idx);
+ }
+ s->op = op;
} else {
Operation op = s->op->ReplaceInputs(s->op, repl);
if (!op.same_as(s->op)) {
import tvm
from tvm import relay
-from tvm.relay import Kind, transform
from tvm.relay.loops import while_loop
from tvm.relay.testing import run_infer_type as infer_type
def int32(val):
return relay.const(val, 'int32')
+def any_dims(ndim):  # Return a tuple holding `ndim` relay.Any() entries (used below as a fully dynamic shape).
+ shape = []
+ for _ in range(ndim):
+ shape.append(relay.Any())
+ return tuple(shape)
+
+# TODO(@wweic): because vm doesn't support heterogeneous exec, we can only test
+# shape function on CPU.
+
+def verify_any_broadcast(x_shape, y_shape, x_np_shape, y_np_shape, op, np_op):  # Check relay binary `op` against numpy `np_op` on the "debug" and "vm" executors.
+ dtype = 'float32'
+ x = relay.var('x', shape=x_shape, dtype=dtype)  # x_shape / y_shape are the declared relay shapes (callers pass relay.Any() entries)
+ y = relay.var('y', shape=y_shape, dtype=dtype)
+ mod = relay.module.Module()
+ mod["main"] = relay.Function([x, y], op(x, y))
+ x_np = np.random.uniform(size=x_np_shape).astype(dtype)  # x_np_shape / y_np_shape are the concrete shapes of the random inputs
+ y_np = np.random.uniform(size=y_np_shape).astype(dtype)
+ for kind in ["debug", "vm"]:
+ ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+ result = ex.evaluate()(x_np, y_np)
+ tvm.testing.assert_allclose(result.asnumpy(), np_op(x_np, y_np))  # numpy result on the same inputs is the reference
+
+def test_any_broadcast():  # Run relay.add vs np.add for several declared shapes containing relay.Any().
+ verify_any_broadcast((relay.Any(),), (3, 2), (1,), (3, 2), relay.add, np.add)
+ verify_any_broadcast((relay.Any(), 2), (1, 2), (1, 2), (1, 2), relay.add, np.add)
+ verify_any_broadcast((relay.Any(), 2), (1, 2), (3, 2), (1, 2), relay.add, np.add)
+ verify_any_broadcast((relay.Any(), 2), (3, 2), (1, 2), (3, 2), relay.add, np.add)
+ verify_any_broadcast((relay.Any(), 2), (3, relay.Any()), (1, 2), (3, 1), relay.add, np.add)
+
+ # The following cases currently fail because topi compute treats Any as 1;
+ # an auto_broadcast buffer will be required to solve the problem.
+ # TODO(@zhiics): Fix this
+ # verify_any_broadcast((relay.Any(),), (3, 2), (2,), (3, 2), relay.add, np.add)
+ # verify_any_broadcast((relay.Any(), 2), (3, 2), (3, 2), (3, 2), relay.add, np.add)
+
+def test_any_concat():  # Concatenate a declared (Any, 2) input with a (1, 2) input along axis 0 and compare with np.concatenate.
+ x = relay.var('x', shape=(relay.Any(), 2), dtype="float32")
+ y = relay.var('y', shape=(1, 2), dtype="float32")
+ z = relay.op.concatenate([x, y], axis=0)
+ mod = relay.module.Module()
+ mod["main"] = relay.Function([x, y], z)
+ x_np = np.random.uniform(size=(3, 2)).astype('float32')  # concrete value supplied for the Any dimension is 3
+ y_np = np.random.uniform(size=(1, 2)).astype('float32')
+ for kind in ["debug", "vm"]:
+ ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+ result = ex.evaluate()(x_np, y_np)
+ ref = np.concatenate([x_np, y_np], axis=0)
+ tvm.testing.assert_allclose(result.asnumpy(), ref)
+
+def verify_any_reshape(x_shape, newshape, x_np_shape, out_shape):  # Reshape an input declared as `x_shape` and check the result on the "debug" and "vm" executors.
+ x = relay.var('x', shape=x_shape, dtype="float32")
+ y = relay.reshape(x, newshape=newshape)
+ mod = relay.module.Module()
+ mod["main"] = relay.Function([x], y)
+ data = np.random.uniform(size=x_np_shape).astype('float32')  # x_np_shape is the concrete shape of the random input
+ for kind in ["debug", "vm"]:
+ ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+ result = ex.evaluate()(data).asnumpy()
+ assert result.shape == out_shape  # out_shape is the expected concrete output shape
+ tvm.testing.assert_allclose(result.flatten(), data.flatten())  # flattened contents must match the input
+
+def test_any_reshape():  # Reshape rank-3 fully dynamic inputs with newshape entries such as 0, -1, -2, -3 and -4.
+ verify_any_reshape(any_dims(3), (1, -1), (2, 3, 4), (1, 24))
+ verify_any_reshape(any_dims(3), (0, -1), (2, 3, 4), (2, 12))
+ verify_any_reshape(any_dims(3), (0, -2), (2, 3, 4), (2, 3, 4))
+ verify_any_reshape(any_dims(3), (-4, 2, -1, -2), (6, 3, 4), (2, 3, 3, 4))
+ verify_any_reshape(any_dims(3), (-4, -1, 2, -3), (6, 3, 4), (3, 2, 12))
+
+def verify_any_take(data_shape, indices_shape, axis, data_np_shape, indices_np_shape):  # Check relay.take against np.take on the "debug" and "vm" executors.
+ mod = relay.Module()
+ data = relay.var('data', shape=data_shape, dtype='float32')  # data_shape / indices_shape are the declared relay shapes
+ indices = relay.var('indices', shape=indices_shape, dtype='int32')
+ y = relay.take(data, indices, axis=axis)
+ mod["main"] = relay.Function([data, indices], y)
+ data_np = np.random.uniform(size=data_np_shape).astype('float32')
+ if axis is None:  # without an axis, random indices are bounded by the total element count
+ max_index = data_np.size
+ else:  # otherwise they are bounded by the length of the chosen axis
+ max_index = data_np.shape[axis]
+ indices_np = np.random.randint(max_index, size=indices_np_shape).astype('int32')
+ ref = np.take(data_np, indices_np, axis=axis)
+ for kind in ["debug", "vm"]:
+ ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+ result = ex.evaluate()(data_np, indices_np)
+ tvm.testing.assert_allclose(result.asnumpy(), ref)
+
+def test_any_take():  # Run verify_any_take with static and fully dynamic index shapes, and with axis 0, 1, -1 and None.
+ verify_any_take(any_dims(2), (1,), 0, (4, 5), (1,))
+ verify_any_take(any_dims(2), (), 0, (4, 5), ())
+ verify_any_take(any_dims(2), (), None, (4, 5), ())
+ verify_any_take(any_dims(3), any_dims(2), 1, (3, 4, 5), (2, 3))
+ verify_any_take(any_dims(2), any_dims(3), None, (4, 5), (2, 3, 4))
+ verify_any_take(any_dims(2), any_dims(4), -1, (4, 5), (2, 3, 4, 5))
+
+def test_any_shape_of():  # Check relay.shape_of on fully dynamic inputs, alone and followed by relay.take.
+ x = relay.var('x', shape=any_dims(2), dtype='float32')
+ y = relay.shape_of(x)
+ mod = relay.module.Module()
+ mod["main"] = relay.Function([x], y)
+ data = np.random.uniform(size=(3, 4)).astype('float32')
+ for kind in ["debug", "vm"]:
+ ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+ result = ex.evaluate()(data)
+ tvm.testing.assert_allclose(result.asnumpy(), np.array([3,4]).astype("int64"))  # expected: the concrete shape of `data`
+
+ x = relay.var('x', shape=any_dims(3), dtype='float32')
+ y0 = relay.shape_of(x)
+ y1 = relay.take(y0, relay.const(1, 'int32'))  # pick element 1 of the shape vector
+ mod = relay.module.Module()
+ mod["main"] = relay.Function([x], y1)
+ data = np.random.uniform(size=(2, 3, 4)).astype('float32')
+ for kind in ["debug", "vm"]:
+ ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+ result = ex.evaluate()(data)
+ tvm.testing.assert_allclose(result.asnumpy(), np.array(3).astype("int64"))  # element 1 of (2, 3, 4) is 3
+
+def test_fused_ops():  # Chain an add and a multiply on an (Any, Any) input and compare with (data + 1) * 2.
+ x = relay.var('x', shape=(relay.Any(), relay.Any()), dtype='float32')
+ y0 = x + relay.const(1.0, 'float32')
+ y1 = y0 * relay.const(2.0, 'float32')
+ mod = relay.module.Module()
+ mod["main"] = relay.Function([x], y1)
+ data = np.random.uniform(size=(5, 4)).astype('float32')
+ for kind in ["vm"]:  # only the "vm" executor is exercised here
+ ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+ result = ex.evaluate()(data)
+ tvm.testing.assert_allclose(result.asnumpy(), (data + 1) * 2)
+
def test_arange_with_dynamic_shape():
m, n, k = relay.ShapeVar('m'), relay.ShapeVar('n'), relay.ShapeVar('k')
x = relay.var('x', shape=(m.var, n.var, k.var), dtype='float32')
y0 = relay.shape_of(x)
y1 = relay.take(y0, relay.const(0, 'int32'))
- y2 = relay.op.arange(y1)
- ex = relay.create_executor()
- f = relay.Function([x], y2, type_params=[m, n, k])
- # TODO(@jroesch): Restore after code generation.
- # data = np.random.rand(10, 5, 3).astype('float32')
- # result = ex.evaluate(f)(data)
- # np.testing.assert_allclose(result.asnumpy(), np.array(range(10)))
-
-def test_dynamic_concat():
+ y2 = relay.op.arange(y1, dtype="int32")  # y1 is element 0 of x's shape vector
+ y3 = y2 + relay.const(1, dtype="int32")
+ data = np.random.rand(10, 5, 3).astype('float32')  # dimension 0 is 10, so the reference below is range(10) + 1
+ mod = relay.module.Module()
+ mod["main"] = relay.Function([x], y3, type_params=[m, n, k])
+ for kind in ["debug", "vm"]:
+ ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+ result = ex.evaluate()(data)
+ tvm.testing.assert_allclose(result.asnumpy(), np.array(range(10)).astype("int32")+1)
+
+def test_recursive_concat():
"""
fn @concat_loop(%i: int32, %st: (any, 1)) -> (any, 1) {
if (%i < 10) {
start = relay.var('start', shape=(), dtype='int32')
body = loop(start, relay.op.reshape(relay.const(0), newshape=(1, 1)))
func = relay.Function([start], relay.TupleGetItem(body, 1))
- func = infer_type(func)
- # TODO(@jroesch, @haichen): We should restore this code when codegeneration
- # is merged
- # ret_shape = func.checked_type.ret_type.shape
- # assert len(ret_shape) == 2, "expected 2-dim output"
- # assert relay.ir_pass.alpha_eq(ret_shape[0], relay.Any())
- # import pdb; pdb.set_trace()
- # mod = relay.module.Module()
- # print(relay.ir_pass.infer_type(func, mod=mod))
- # ret = relay.Call(loop, [relay.const(0, 'int32'), init])
- # mod[mod.entry_func] = relay.Function([], ret)
- # print(relay.ir_pass.infer_type(mod[mod.entry_func], mod=mod))
-
- # initial = np.array(0.0, dtype='float32').reshape((1,))
- # iter_stop = np.array(10, dtype='int32')
- # ex = relay.create_executor("debug", mod=mod, ctx=tvm.cpu(), target="llvm")
- # result = ex.evaluate(mod.entry_func)()
- # np.testing.assert_allclose(result.asnumpy(), np.array(range(10)))
-
-def test_dynamic_concat_with_wrong_annotation():
+ mod = relay.module.Module()
+ mod["main"] = func
+ data = np.array(0.0, dtype='int32')
+ # TODO(@jroesch): After LambdaLift pass, TypeInfer pass will fail
+ # so currently we cannot run this test case on VM
+ for kind in ["debug"]:
+ ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
+ result = ex.evaluate()(data)
+ ref = np.array([0] + list(range(10))).reshape((11, 1)).astype("int32")
+ np.testing.assert_allclose(result.asnumpy(), ref)
+
+def test_recursive_concat_with_wrong_annotation():
"""
v0.0.1
fn (%start: int32) {
assert "in particular dimension 0 conflicts 2 does not match 1" in str(e)
if __name__ == "__main__":
+ test_any_broadcast()  # the relay.Any() tests added by this change run before the pre-existing ones
+ test_any_concat()
+ test_any_reshape()
+ test_any_take()
+ test_any_shape_of()
+ test_fused_ops()
test_arange_with_dynamic_shape()
- test_dynamic_concat()
- test_dynamic_concat_with_wrong_annotation()
+ test_recursive_concat()  # replaces the call to test_dynamic_concat removed above
+ test_recursive_concat_with_wrong_annotation()
vm = create_vm(mod)
ser = serializer.Serializer(vm)
- stats = ser.stats
- assert "scalar" in stats
-
glbs = ser.globals
assert len(glbs) == 3
assert "f1" in glbs
code = ser.bytecode
assert "main 5 2 5" in code
- assert "f1 3 1 4" in code
- assert "f2 3 1 4" in code
+ assert "f1 2 1 3" in code
+ assert "f2 2 1 3" in code
code, lib = ser.serialize()
assert isinstance(code, bytearray)
assert ibody.min.value == 0
assert ibody.extent.name == 'm'
#Check loop body
- jbody = ibody.body
+ jblock = ibody.body
+ assert isinstance(jblock, tvm.stmt.Block)
+ jbody = jblock.first
assert isinstance(jbody, tvm.stmt.AssertStmt)
assert isinstance(jbody.message, tvm.expr.StringImm)
assert jbody.message.value == "index out of range!"
- jbody = jbody.body
+ jbody = jblock.rest
assert isinstance(jbody, tvm.stmt.Provide)
assert jbody.func.name == 'c'
assert len(jbody.args) == 2
tvm::Expr one(1);
int i;
for (i = 1; i <= std::min(s1_size, s2_size); ++i) {
+ // TODO(@icemelon9): Need to revisit this part
+ const Variable* var1 = shape1[s1_size - i].as<Variable>();
+ const Variable* var2 = shape2[s2_size - i].as<Variable>();
bh.all_vars.push_front(tvm::Var());
if (topi::detail::EqualCheck(shape1[s1_size - i], shape2[s2_size - i])) {
bh.common_shape.push_front(shape1[s1_size - i]);
} else if (topi::detail::EqualCheck(one, shape2[s2_size - i])) {
bh.common_shape.push_front(shape1[s1_size - i]);
bh.vars1.push_front(bh.all_vars[0]);
+ } else if (var1 && var2) {
+ bh.common_shape.push_front(max(shape1[s1_size - i], shape2[s2_size - i]));
+ bh.vars1.push_front(bh.all_vars[0]);
+ bh.vars2.push_front(bh.all_vars[0]);
+ } else if (var1) {
+ bh.common_shape.push_front(shape2[s2_size - i]);
+ bh.vars2.push_front(bh.all_vars[0]);
+ } else if (var2) {
+ bh.common_shape.push_front(shape1[s1_size - i]);
+ bh.vars1.push_front(bh.all_vars[0]);
} else {
CHECK(false) << "Incompatible broadcast dims: " << shape1[s1_size - i]
<< " and " << shape2[s2_size - i] << " in: "
return compute(output_shape, func, name, tag);
}
-inline Tensor arange(const Expr start,
- const Expr stop,
- const Expr step,
+inline Tensor arange(const Expr& start,
+ const Expr& stop,
+ const Expr& step,
Type dtype,
std::string name = "T_arange",
std::string tag = kInjective) {
func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
func(a, w, c)
- rtol = 1e-5
- if (kernel > 3):
- rtol = 2e-5
+ rtol = 1e-3
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=rtol)