From: Haichen Shen Date: Sun, 1 Sep 2019 01:50:22 +0000 (-0700) Subject: [Relay][Any] Add shape func for dynamic shape (#3606) X-Git-Tag: upstream/0.7.0~1976 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=eef35a57d95c650e490b168f1f585d9ec00412ee;p=platform%2Fupstream%2Ftvm.git [Relay][Any] Add shape func for dynamic shape (#3606) * init shape func in interpreter and vm compiler * Update interpreter * fix * lint * lint * fix * remove hack * update * fix * fix * update * address comments & update for shape_of * fix lint * update * fix hybrid * lint * fix bug & add take shape func * lint * lint * update * fix flaky test * add todo --- diff --git a/include/tvm/ir.h b/include/tvm/ir.h index 03559d3..a9c6c4b 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -704,6 +704,10 @@ class Reduce : public ExprNode { class Any : public ExprNode { public: void VisitAttrs(AttrVisitor* v) final {} + /*! \brief Convert to var. */ + Var ToVar() const { + return Variable::make(Int(32), "any_dim"); + } TVM_DLL static Expr make(); diff --git a/include/tvm/relay/op_attr_types.h b/include/tvm/relay/op_attr_types.h index c1a0f83..741e8b4 100644 --- a/include/tvm/relay/op_attr_types.h +++ b/include/tvm/relay/op_attr_types.h @@ -75,6 +75,11 @@ using TOpIsStateful = bool; using TNonComputational = bool; /*! + * \brief Mark the operator whether output shape is data dependant. + */ +using TShapeDataDependant = bool; + +/*! * \brief Computation description interface. * * \note This function have a special convention @@ -186,7 +191,7 @@ using Shape = Array; using FShapeFunc = runtime::TypedPackedFunc< Array(const Attrs& attrs, const Array& inputs, - const Array& out_shapes)>; + const Array& out_ndims)>; } // namespace relay } // namespace tvm diff --git a/python/tvm/hybrid/parser.py b/python/tvm/hybrid/parser.py index 40ea171..171a2f8 100644 --- a/python/tvm/hybrid/parser.py +++ b/python/tvm/hybrid/parser.py @@ -25,7 +25,7 @@ import numbers from enum import Enum -from .util import _internal_assert, _apply_indices +from .util import _internal_assert from . import calls from . import util from .preprocessor import determine_variable_usage @@ -35,7 +35,6 @@ from ..container import Array from ..tensor import Tensor, Operation from .. import _api_internal as _tvm_internal from .. import expr as _expr -from .. import stmt as _stmt from .. import make as _make from .. import api as _api from .. import ir_pass as _ir_pass @@ -43,16 +42,15 @@ from .. import ir_pass as _ir_pass def concat_list_to_block(lst): """Concatenate a list of Python IR nodes to HalideIR Block""" + if not lst: + return util.make_nop() n = len(lst) if n == 1: return lst[0] body = lst[n - 1] for i in range(1, n): stmt = lst[n - 1 - i] - if isinstance(stmt, _stmt.AssertStmt): - body = _make.AssertStmt(stmt.condition, stmt.message, body) - else: - body = _make.Block(stmt, body) + body = _make.Block(stmt, body) return body @@ -100,8 +98,8 @@ class HybridParser(ast.NodeVisitor): ast.LtE : operator.le, ast.Eq : operator.eq, ast.NotEq : operator.ne, - ast.And : _all, - ast.Or : _any, + ast.And : _all, + ast.Or : _any, } @@ -179,6 +177,9 @@ class HybridParser(ast.NodeVisitor): to_pop = [] for key, val in self.usage.items(): _, level, _ = val + if key not in self.symbols: + # don't realize the symbols that are never visited + continue if level != node: continue _internal_assert(key in self.symbols.keys(), "Unknown symbol %s!" % key) @@ -363,44 +364,25 @@ class HybridParser(ast.NodeVisitor): def visit_Attribute(self, node): - _internal_assert(isinstance(node.value, ast.Name), \ - "For atrribute access, only both names are supported so far!") buf = self.visit(node.value) return getattr(buf, node.attr) def visit_Subscript(self, node): args = self.visit(node.slice) - if isinstance(node.value, ast.Name): - if node.value.id in self.closure_vars: - args = ast.literal_eval(str(args)) - return _api.convert(_apply_indices(self.closure_vars[node.value.id], args)) - - buf = self.visit(node.value) - if isinstance(buf, Array): - for i in args: - if isinstance(i, numbers.Integral): - buf = buf[i] - else: - _internal_assert(isinstance(i, (_expr.IntImm, _expr.UIntImm)), \ - "All indices are supposed to be constants") - buf = buf[i.value] - - return buf - - if isinstance(node.ctx, ast.Load): - return _make.Call(buf.dtype, buf.name, args, \ - _expr.Call.Halide, buf.op, buf.value_index) - - return buf, args - - shape = self.visit(node.value) - _internal_assert(len(args) == 1, "For 'shape' access the argument should be only one!") - args = args[0] - #TODO: maybe support non-constant value later? - _internal_assert(isinstance(args, (_expr.IntImm, _expr.UIntImm)), \ - "So far only constant shape access supported!") - return shape[args.value] - + arr = self.visit(node.value) + if isinstance(arr, Array): + for i in args: + if isinstance(i, numbers.Integral): + arr = arr[i] + else: + _internal_assert(isinstance(i, (_expr.IntImm, _expr.UIntImm)), \ + "All indices are supposed to be constants") + arr = arr[i.value] + return arr + if isinstance(node.ctx, ast.Load): + return _make.Call(arr.dtype, arr.name, args, + _expr.Call.Halide, arr.op, arr.value_index) + return arr, args def visit_With(self, node): if sys.version_info[0] < 3: @@ -417,7 +399,7 @@ class HybridParser(ast.NodeVisitor): def visit_If(self, node): - cond = self.visit(node.test) + cond = _ir_pass.CanonicalSimplify(self.visit(node.test)) # Return no IfThenElse if proven if isinstance(cond, _expr.UIntImm): @@ -508,11 +490,11 @@ class HybridParser(ast.NodeVisitor): _name = node.target.id if isinstance(for_type, tuple): - low = _ir_pass.Simplify(low) - ext = _ir_pass.Simplify(ext) + low = _ir_pass.CanonicalSimplify(low) + ext = _ir_pass.CanonicalSimplify(ext) _internal_assert(isinstance(low, _expr.ConstExpr) and isinstance(ext, _expr.ConstExpr), \ - "Const range should start from a const" + \ + "Const range should start from a const " + \ "and iterate const times") low, ext = low.value, ext.value diff --git a/python/tvm/hybrid/util.py b/python/tvm/hybrid/util.py index 058c5aa..0dd1fa1 100644 --- a/python/tvm/hybrid/util.py +++ b/python/tvm/hybrid/util.py @@ -101,9 +101,3 @@ def _is_tvm_arg_types(args): _internal_assert(isinstance(elem, np_arg_types), \ "Expect a numpy type but %s get!" % str(type(elem))) return False - -def _apply_indices(value, indices): - """Apply multidimensional index""" - if indices: - return _apply_indices(value[indices[0]], indices[1:]) - return value diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index d4cc68b..745ef63 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -177,6 +177,10 @@ class VMCompiler(object): The VM runtime. """ target = _update_target(target) + target_host = None if target_host == "" else target_host + if not target_host: + target_host = "llvm" if tvm.module.enabled("llvm") else "stackvm" + target_host = tvm.target.create(target_host) self._compile(mod, target, target_host) return VirtualMachine(self._get_vm()) diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index 176def3..2e34233 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -14,12 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -#pylint: disable=invalid-name, unused-argument +#pylint: disable=invalid-name, unused-argument, len-as-condition """Backend compiler related feature registration""" from __future__ import absolute_import import topi -from .op import register_compute, register_schedule, register_pattern +from .op import register_compute, register_schedule, register_pattern, register_shape_func from .op import schedule_injective, OpPattern +from ...hybrid import script schedule_broadcast = schedule_injective schedule_elemwise = schedule_injective @@ -104,3 +105,49 @@ def clip_compute(attrs, inputs, output_type, target): return [topi.clip(inputs[0], attrs.a_min, attrs.a_max)] register_schedule("clip", schedule_elemwise) + +# shape func +@script +def _broadcast_shape_func(x, y, ndim): + out = output_tensor((ndim,), "int64") + if len(x.shape) == 0: + for i in const_range(ndim): + out[i] = y[i] + elif len(y.shape) == 0: + for i in const_range(ndim): + out[i] = x[i] + else: + ndim1 = x.shape[0] + ndim2 = y.shape[0] + for i in const_range(1, min(ndim1, ndim2)+1): + if x[ndim1-i] == y[ndim2-i]: + out[ndim-i] = x[ndim1-i] + elif x[ndim1-i] == 1: + out[ndim-i] = y[ndim2-i] + else: + assert y[ndim2 - i] == 1, "Incompatible broadcast type %s and %s" % ( + x[ndim1-i], y[ndim2-i]) + out[ndim-i] = x[ndim1-i] + for i in const_range(min(ndim1, ndim2)+1, ndim+1): + if ndim1 >= ndim2: + out[ndim-i] = x[ndim1-i] + else: + out[ndim-i] = y[ndim2-i] + return out + +def broadcast_shape_func(attrs, inputs, out_ndims): + return [_broadcast_shape_func(*inputs, out_ndims[0])] + +register_shape_func("add", False, broadcast_shape_func) +register_shape_func("subtract", False, broadcast_shape_func) +register_shape_func("multiply", False, broadcast_shape_func) +register_shape_func("divide", False, broadcast_shape_func) +register_shape_func("mod", False, broadcast_shape_func) +register_shape_func("logical_and", False, broadcast_shape_func) +register_shape_func("logical_or", False, broadcast_shape_func) +register_shape_func("equal", False, broadcast_shape_func) +register_shape_func("not_equal", False, broadcast_shape_func) +register_shape_func("less", False, broadcast_shape_func) +register_shape_func("less_equal", False, broadcast_shape_func) +register_shape_func("greater", False, broadcast_shape_func) +register_shape_func("greater_equal", False, broadcast_shape_func) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index a4c9375..7f29e85 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -15,11 +15,14 @@ # specific language governing permissions and limitations # under the License. """Backend compiler related feature registration""" -# pylint: disable=invalid-name,unused-argument +# pylint: disable=invalid-name,unused-argument, len-as-condition from __future__ import absolute_import +from topi.util import get_const_int, get_const_tuple from . import op as _reg from ._reduce import _schedule_reduce from .op import OpPattern +from ...hybrid import script +from ...api import convert schedule_injective = _reg.schedule_injective schedule_broadcast = _reg.schedule_injective @@ -58,3 +61,145 @@ _reg.register_schedule("one_hot", schedule_injective) # layout_transform _reg.register_schedule("layout_transform", schedule_injective) _reg.register_pattern("layout_transform", OpPattern.INJECTIVE) + +# shape func +@script +def _arange_shape_func(start, stop, step): + out = output_tensor((1,), "int64") + out[0] = int64(ceil_div((float32(stop[0]) - float32(start[0])), float32(step[0]))) + return out + +@_reg.register_shape_func("arange", True) +def arange_shape_func(attrs, inputs, _): + return [_arange_shape_func(*inputs)] + +@script +def _concatenate_shape_func(inputs, axis): + ndim = inputs[0].shape[0] + out = output_tensor((ndim,), "int64") + for i in const_range(ndim): + if i != axis: + out[i] = inputs[0][i] + for j in const_range(1, len(inputs)): + assert out[i] == inputs[j][i], \ + "Dims mismatch in the inputs of concatenate." + else: + out[i] = int64(0) + for j in const_range(len(inputs)): + out[i] += inputs[j][i] + return out + +@_reg.register_shape_func("concatenate", False) +def concatenate_shape_func(attrs, inputs, _): + axis = get_const_int(attrs.axis) + return [_concatenate_shape_func(inputs, convert(axis))] + +@script +def _reshape_shape_func(data_shape, newshape, ndim): + out = output_tensor((ndim,), "int64") + src_idx = 0 + dst_idx = 0 + infer_idx = -1 + copy = False + skip = 0 + for i in const_range(len(newshape)): + if skip > 0: + skip -= 1 + elif newshape[i] > 0: + out[dst_idx] = int64(newshape[i]) + src_idx += 1 + dst_idx += 1 + elif newshape[i] == 0: + out[dst_idx] = data_shape[src_idx] + src_idx += 1 + dst_idx += 1 + elif newshape[i] == -1: + assert infer_idx < 0, "One and only one dim can be inferred" + out[dst_idx] = int64(1) + infer_idx = i + dst_idx += 1 + elif newshape[i] == -2: + copy = True + elif newshape[i] == -3: + assert data_shape.shape[0] - src_idx > 1, \ + "Not enough dims in input shape for -3" + out[dst_idx] = data_shape[src_idx] * data_shape[src_idx+1] + src_idx += 2 + dst_idx += 1 + elif newshape[i] == -4: + assert len(newshape) - i > 2, "Not enough dims in new shape for -4" + if newshape[i+1] == -1: + assert newshape[i+2] != -1, "Split dims cannot both be -1." + out[dst_idx] = data_shape[src_idx] / int64(newshape[i+2]) + out[dst_idx+1] = int64(newshape[i+2]) + else: + out[dst_idx] = int64(newshape[i+1]) + if newshape[i+2] == -1: + out[dst_idx+1] = data_shape[src_idx] / int64(newshape[i+1]) + else: + out[dst_idx+1] = int64(newshape[i+2]) + assert data_shape[src_idx] == out[dst_idx] * out[dst_idx+1],\ + "Product of split dims doesn't match to input dim" + src_idx += 1 + dst_idx += 2 + skip = 2 + else: + assert False, "Invalid special values in new shape" + if len(data_shape.shape) > 0: + # if data is not constant, we can then handle -1 and -2 + if copy: + for i in range(src_idx, data_shape.shape[0]): + out[dst_idx] = data_shape[i] + dst_idx += 1 + if infer_idx >= 0: + old_size = int64(1) + for i in const_range(data_shape.shape[0]): + old_size *= data_shape[i] + new_size = int64(1) + for i in const_range(out.shape[0]): + new_size *= out[i] + out[infer_idx] = old_size / new_size + return out + +@_reg.register_shape_func("reshape", False) +def reshape_shape_func(attrs, inputs, out_ndims): + newshape = get_const_tuple(attrs.newshape) + return [_reshape_shape_func(inputs[0], convert(newshape), out_ndims[0])] + +@script +def _take_no_axis_shape_func(indices_shape, out_ndim): + out = output_tensor((out_ndim,), "int64") + for i in const_range(out_ndim): + out[i] = indices_shape[i] + return out + +@script +def _take_with_axis_shape_func(data_shape, indices_shape, axis, out_ndim): + out = output_tensor((out_ndim,), "int64") + for i in const_range(axis): + out[i] = data_shape[i] + if len(indices_shape.shape) == 0: + # indices is constant + for i in const_range(axis+1, len(data_shape)): + out[i-1] = data_shape[i] + else: + for i in const_range(len(indices_shape)): + out[axis+i] = indices_shape[i] + for i in const_range(axis+1, len(data_shape)): + out[len(indices_shape)+i-1] = data_shape[i] + return out + +@_reg.register_shape_func("take", False) +def take_shape_func(attrs, inputs, out_ndims): + """ + Shape function for take op. + """ + if attrs.axis is None: + return [_take_no_axis_shape_func(inputs[1], out_ndims[0])] + else: + axis = get_const_int(attrs.axis) + data_ndim = int(inputs[0].shape[0]) + if axis < 0: + axis += data_ndim + assert 0 <= axis < data_ndim + return [_take_with_axis_shape_func(*inputs, convert(axis), out_ndims[0])] diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index e07d153..fcbc3fd 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -48,6 +48,22 @@ class Op(Expr): """ return _OpGetAttr(self, attr_name) + def set_attr(self, attr_name, value, plevel=10): + """Set attribute about the operator. + + Parameters + ---------- + attr_name : str + The attribute name + + value : object + The attribute value + + plevel : int + The priority level + """ + _OpSetAttr(self, attr_name, value, plevel) + def get(op_name): """Get the Op for a given name @@ -219,6 +235,26 @@ def register_gradient(op_name, fgradient=None, level=10): """ return register(op_name, "FPrimalGradient", fgradient, level) +def register_shape_func(op_name, data_dependant, shape_func=None, level=10): + """Register operator shape function for an op. + + Parameters + ---------- + op_name : str + The name of the op. + + data_dependant : bool + Whether the shape function depends on input data. + + shape_func : function (attrs: Attrs, inputs: List[Tensor], out_ndims: List[IndexExpr]) + -> shape_tensors: List + The function for computing the dynamic output shapes + + level : int + The priority level + """ + get(op_name).set_attr("TShapeDataDependant", data_dependant, level) + return register(op_name, "FShapeFunc", shape_func, level) _init_api("relay.op", __name__) diff --git a/src/op/hybrid_op.cc b/src/op/hybrid_op.cc index f9814e8..934e894 100644 --- a/src/op/hybrid_op.cc +++ b/src/op/hybrid_op.cc @@ -82,7 +82,25 @@ Operation HybridOpNode::make(std::string name, } Array HybridOpNode::InputTensors() const { - return inputs; + // Because input tensors could be potentially inlined into hybrid scripts, + // we need to check if all input tensors are used in the body. + std::unordered_set orig_inputs; + for (auto t : inputs) { + orig_inputs.insert(t); + } + std::unordered_set visited; + Array curr_inputs; + ir::PostOrderVisit(body, [&curr_inputs, &orig_inputs, &visited](const NodeRef& n) { + const ir::Call *call = n.as(); + if (call != nullptr && call->func.defined()) { + Tensor t = Operation(call->func.node_).output(call->value_index); + if (orig_inputs.count(t) && !visited.count(t)) { + curr_inputs.push_back(t); + visited.insert(t); + } + } + }); + return curr_inputs; } Operation HybridOpNode::ReplaceInputs( @@ -111,7 +129,8 @@ void HybridOpNode::PropBoundToInputs( arith::Analyzer* analyzer, const std::unordered_map &dom_map, std::unordered_map* out_dom_map) const { - for (Tensor t : this->inputs) { + auto curr_inputs = InputTensors(); + for (Tensor t : curr_inputs) { auto it = out_dom_map->find(t); if (it == out_dom_map->end()) continue; TensorDom &dom = it->second; @@ -180,11 +199,10 @@ Stmt HybridOpNode::BuildProvide( outputs[i]->dtype); f_push_bind(buffer, stage->op.output(i)); } - for (int i = static_cast(inputs.size()) - 1; i >= 0; --i) { - Buffer buffer = decl_buffer( - inputs[i]->shape, - inputs[i]->dtype); - f_push_bind(buffer, inputs[i]); + auto curr_inputs = InputTensors(); + for (int i = static_cast(curr_inputs.size()) - 1; i >= 0; --i) { + Buffer buffer = decl_buffer(curr_inputs[i]->shape, curr_inputs[i]->dtype); + f_push_bind(buffer, curr_inputs[i]); } std::unordered_map rmap; @@ -203,7 +221,7 @@ Stmt HybridOpNode::BuildProvide( * tensors have the same names as the operation produces them. * 2. Once OpNode is wrapped up by an Operation node, it is finalized. * Later access will be from a const OpNode*. - * This is a chiken-egg paradox. It is impossible to put the output + * This is a chicken-egg paradox. It is impossible to put the output * tensors into the function body without forming the op node. The * function body is immutable after the node is formed. * diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc index eba1cee..e118202 100644 --- a/src/pass/storage_rewrite.cc +++ b/src/pass/storage_rewrite.cc @@ -41,7 +41,7 @@ namespace ir { using runtime::StorageRank; using runtime::StorageScope; -// Find a linear pattern of storage acess +// Find a linear pattern of storage access // Used for liveness analysis. // Composite scopes(loop/thread_launch/IfThen) is represented by two points: // before_scope -> scope_body -> after_scope @@ -193,6 +193,10 @@ class LinearAccessPatternFinder final : public IRVisitor { VisitNewScope(op); } + void Visit_(const AssertStmt* op) final { + VisitNewScope(op); + } + // linearized access sequence. std::vector linear_seq_; // The storage scope of each buffer diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index ab90631..c88703e 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -35,7 +35,9 @@ #include #include #include +#include #include +#include "../ir/type_functor.h" #include "compile_engine.h" namespace tvm { @@ -48,6 +50,43 @@ CCacheKey CCacheKeyNode::make(Function source_func, Target target) { return CCacheKey(n); } +struct IsDynamicVisitor : public TypeVisitor { + bool is_dyn{false}; + void VisitType_(const TensorTypeNode* tt) { + for (auto dim : tt->shape) { + if (dim.as()) { + is_dyn = true; + break; + } + } + } +}; + +bool IsDynamic(const Type& ty) { + IsDynamicVisitor v; + v.VisitType(ty); + return v.is_dyn; +} + +Array GetShape(const Array& shape) { + // for now, we always use int32 shape when possible + // even if the result of shape inference becomes int64. + Array res; + for (IndexExpr val : shape) { + const int64_t* pval = as_const_int(val); + if (pval != nullptr) { + CHECK_LE(pval[0], std::numeric_limits::max()); + CHECK_GE(pval[0], std::numeric_limits::min()); + res.push_back(ir::IntImm::make(Int(32), *pval)); + } else if (val->is_type()) { + res.push_back(val.as()->ToVar()); + } else { + res.push_back(val); + } + } + return res; +} + // The getter to get schedule from compile engine. // Get schedule from functor. class ScheduleGetter : @@ -56,23 +95,6 @@ class ScheduleGetter : explicit ScheduleGetter(Target target) : target_(target) {} - Array GetShape(const Array& shape) { - // for now, we always use int32 shape when possible - // even if the result of shape inference becomes int64. - Array res; - for (IndexExpr val : shape) { - const int64_t* pval = as_const_int(val); - if (pval != nullptr) { - CHECK_LE(pval[0], std::numeric_limits::max()); - CHECK_GE(pval[0], std::numeric_limits::min()); - res.push_back(ir::IntImm::make(Int(32), *pval)); - } else { - res.push_back(val); - } - } - return res; - } - std::pair Create(const Function& prim_func) { static auto fschedule = Op::GetAttr("FTVMSchedule"); @@ -90,6 +112,7 @@ class ScheduleGetter : const auto* tuple_type = param->type_as(); for (Type field : tuple_type->fields) { const auto* ttype = field.as(); + // TODO(@icemelon): Allow recursive tuple CHECK(ttype != nullptr); tvm::Tensor tensor = tvm::placeholder( GetShape(ttype->shape), ttype->dtype); @@ -283,6 +306,255 @@ class ScheduleGetter : Array scalars_; }; +// Creates shape function from functor. +class MakeShapeFunc : public ExprFunctor(const Expr&)> { + public: + MakeShapeFunc() {} + + std::pair Create(const Function& prim_func) { + for (auto param : prim_func->params) { + param_states_[param] = kNoNeed; + Array data_inputs; + Array shape_inputs; + + auto add_placeholder = [&data_inputs, &shape_inputs](const TensorTypeNode* ttype) { + // Add data placeholder + Shape shape = GetShape(ttype->shape); + tvm::Tensor data_tensor = tvm::placeholder(shape, ttype->dtype); + data_inputs.push_back(data_tensor); + // Add shape placeholder + int64_t ndim = shape.size(); + Shape sshape; + if (ndim > 0) { + sshape.push_back(tvm::Integer(ndim)); + } + tvm::Tensor shape_tensor = tvm::placeholder(sshape, Int(64)); + shape_inputs.push_back(shape_tensor); + }; + + if (const auto *ttype = param->checked_type().as()) { + add_placeholder(ttype); + } else { + // flatten tuple of tensor type. + const auto *tuple_type = param->type_as(); + // TODO(@icemelon): Support recursive tuple + CHECK(tuple_type); + for (Type field : tuple_type->fields) { + const auto *ttype = field.as(); + CHECK(ttype); + add_placeholder(ttype); + } + } + param_data_[param] = data_inputs; + param_shapes_[param] = shape_inputs; + } + readable_name_stream_ << "shape_func"; + auto cache_node = make_node(); + cache_node->outputs = VisitExpr(prim_func->body); + auto candidate_name = readable_name_stream_.str(); + constexpr static size_t kMaxFuncNameLength = 80; + if (candidate_name.size() > kMaxFuncNameLength) { + std::stringstream truncated_name; + truncated_name << candidate_name.substr(0, kMaxFuncNameLength); + truncated_name << "_" << std::hash{}(candidate_name) << "_"; + candidate_name = truncated_name.str(); + } + cache_node->func_name = candidate_name; + + // set inputs + for (auto param : prim_func->params) { + int state = param_states_[param]; + cache_node->shape_func_param_states.push_back(IntImm::make(Int(32), state)); + if (state & kNeedInputData) { + for (auto t : param_data_[param]) { + cache_node->inputs.push_back(t); + } + } + if (state & kNeedInputShape) { + for (auto t : param_shapes_[param]) { + cache_node->inputs.push_back(t); + } + } + } + + CachedFunc cfunc(cache_node); + // generate schedule for shape func + Array out_ops; + for (auto t : cache_node->outputs) { + out_ops.push_back(t->op); + } + auto schedule = create_schedule(out_ops); + tvm::schedule::AutoInlineInjective(schedule); + for (const auto& scalar : scalars_) { + auto scalar_op = scalar->op; + if (schedule->Contain(scalar_op)) { + schedule[scalar_op].compute_inline(); + } + } + return std::make_pair(schedule, cfunc); + } + + Array VisitExpr(const Expr& expr) { + auto it = memo_.find(expr); + if (it != memo_.end()) { + return it->second; + } else { + Array res = ExprFunctor::VisitExpr(expr); + if (expr.as() == nullptr) { + // Do not memoize vars because shape functions could use either the data + // or the shape of a var each time. + memo_[expr] = res; + } + return res; + } + } + + Array VisitExpr_(const VarNode* var_node) final { + auto var = GetRef(var_node); + auto it = param_states_.find(var); + if (it == param_states_.end()) { + LOG(FATAL) << "Free variable " << var->name_hint(); + return {}; + } else { + CHECK(data_dependants_.size()); + bool data_dependant = data_dependants_.back(); + if (data_dependant) { + param_states_[var] |= kNeedInputData; + return param_data_[var]; + } else { + param_states_[var] |= kNeedInputShape; + return param_shapes_[var]; + } + } + } + + Array VisitExpr_(const ConstantNode* op) final { + CHECK(data_dependants_.size()); + CHECK(op->is_scalar()); + bool data_dependant = data_dependants_.back(); + if (data_dependant) { + void* data = op->data->data; + DataType dtype = TVMType2Type(op->data->dtype); + Tensor value = tvm::compute({}, [&](const Array&) { + if (dtype == Int(32)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == Int(64)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == Float(32)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == Float(64)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == Bool()) { + return make_const(dtype, static_cast(data)[0]); + } else { + LOG(FATAL) << "not handled"; + return tvm::Expr(); + } + }, "data_const", topi::kBroadcast); + scalars_.push_back(value); + return {value}; + } else { + Tensor value = tvm::compute({}, [&](const Array&) { + return make_const(Int(64), 0); + }, "shape_const", topi::kBroadcast); + scalars_.push_back(value); + return {value}; + } + } + + Array VisitExpr_(const CallNode* call_node) final { + static auto fshape_func = Op::GetAttr("FShapeFunc"); + static auto tshape_data_dependant = Op::GetAttr( + "TShapeDataDependant"); + CHECK(call_node->op.as()) + << "Primitive function only allows call into primitive ops"; + Op op = Downcast(call_node->op); + CHECK(data_dependants_.empty() || !data_dependants_.back()) + << "Error in op fusion: output of the shape func is fed to a " + << "data-dependant shape func"; + CHECK_GT(fshape_func.count(op), 0) + << "Internal error, cannot find ShapeFunc for " << op->name; + CHECK_GT(tshape_data_dependant.count(op), 0) + << "Internal error, cannot find TShapeDataDependant for " << op->name; + + data_dependants_.push_back(tshape_data_dependant[op]); + // Visit all inputs + Array inputs; + int count_tuple = 0; + for (Expr arg : call_node->args) { + if (arg->checked_type().as()) { + ++count_tuple; + } + for (Tensor tensor : VisitExpr(arg)) { + inputs.push_back(tensor); + } + } + if (count_tuple) { + CHECK_EQ(call_node->args.size(), 1U) + << "Only allow function with a single tuple input"; + } + // Get output ndims + auto ret_type = call_node->checked_type(); + Array out_ndims; + if (const auto* ttype = ret_type.as()) { + out_ndims.push_back(IntImm::make(Int(32), ttype->shape.size())); + } else { + auto rtype = ret_type.as(); + // TODO(@icemelon): Allow recursive tuple + CHECK(rtype); + for (size_t i = 0; i < rtype->fields.size(); ++i) { + auto ttype = rtype->fields[i].as(); + CHECK(ttype); + out_ndims.push_back(IntImm::make(Int(32), ttype->shape.size())); + } + } + // Call shape function + auto outputs = fshape_func[op](call_node->attrs, inputs, out_ndims); + data_dependants_.pop_back(); + readable_name_stream_ << "_" << op->name; + return outputs; + } + + Array VisitExpr_(const FunctionNode* op) final { + LOG(FATAL) << "Do not support sub function"; + return Array(); + } + + Array VisitExpr_(const LetNode* op) final { + Array val = VisitExpr(op->value); + CHECK(!memo_.count(op->var)); + memo_[op->var] = val; + return VisitExpr(op->body); + } + + Array VisitExpr_(const TupleNode* op) final { + Array fields; + for (Expr field : op->fields) { + CHECK(field->checked_type().as()) + << "Only allow Tuple of Tensor"; + Array res = VisitExpr(field); + CHECK_EQ(res.size(), 1); + fields.push_back(res[0]); + } + return fields; + } + + private: + /*! \brief String stream for function name */ + std::ostringstream readable_name_stream_; + /*! \brief Map from parameter to its shape function usage state */ + std::unordered_map param_states_; + /*! \brief Map from parameter to list of data placeholder */ + std::unordered_map, NodeHash, NodeEqual> param_data_; + /*! \brief Map from parameter to list of shape placeholder */ + std::unordered_map, NodeHash, NodeEqual> param_shapes_; + /*! \brief Memoized visit result */ + std::unordered_map, NodeHash, NodeEqual> memo_; + /*! \brief Stack of data dependencies for shape function */ + std::vector data_dependants_; + /*! \brief Scalars used in the shape function */ + Array scalars_; +}; class CompileEngineImpl : public CompileEngineNode { public: @@ -304,6 +576,11 @@ class CompileEngineImpl : public CompileEngineNode { } return value->packed_func; } + + CachedFunc LowerShapeFunc(const CCacheKey& key) final { + return LowerShapeFuncInternal(key)->cached_func; + } + void Clear() final { cache_.clear(); } @@ -379,6 +656,40 @@ class CompileEngineImpl : public CompileEngineNode { value->cached_func = CachedFunc(cache_node); return value; } + // implement lowered shape func + CCacheValue LowerShapeFuncInternal(const CCacheKey& key) { + std::lock_guard lock(mutex_); + CCacheValue value; + auto it = shape_func_cache_.find(key); + if (it != shape_func_cache_.end()) { + it->second->use_count += 1; + if (it->second->cached_func.defined()) return it->second; + value = it->second; + } else { + value = CCacheValue(make_node()); + value->use_count = 0; + shape_func_cache_[key] = value; + } + // Enforce use the target. + With target_scope(key->target); + + CHECK(!value->cached_func.defined()); + auto spair = MakeShapeFunc().Create(key->source_func); + auto cache_node = make_node( + *(spair.second.operator->())); + cache_node->func_name = GetUniqueName(cache_node->func_name); + cache_node->target = key->target; + + Array all_args = cache_node->inputs; + for (Tensor arg : cache_node->outputs) { + all_args.push_back(arg); + } + tvm::BuildConfig bcfg = BuildConfig::Create(); + std::unordered_map binds; + cache_node->funcs = tvm::lower(spair.first, all_args, cache_node->func_name, binds, bcfg); + value->cached_func = CachedFunc(cache_node); + return value; + } /*! * \brief Get unique name from name. * \param name The orginal name. @@ -408,6 +719,8 @@ class CompileEngineImpl : public CompileEngineNode { std::unordered_map name_map_; /*! \brief internal compiler cache */ std::unordered_map cache_; + /*! \brief internal compiler cache for shape funcs */ + std::unordered_map shape_func_cache_; }; /*! \brief The global compile engine */ diff --git a/src/relay/backend/compile_engine.h b/src/relay/backend/compile_engine.h index 9765cf9..84197db 100644 --- a/src/relay/backend/compile_engine.h +++ b/src/relay/backend/compile_engine.h @@ -36,6 +36,14 @@ namespace tvm { namespace relay { +/*! \brief Indicate whether the data or shape or both of a parameter is used in the shape func. */ +enum ShapeFuncParamState { + kNoNeed = 0, + kNeedInputData = 1, + kNeedInputShape = 2, + kNeedBoth = 3, +}; + /*! \brief Node container to represent a cached function. */ struct CachedFuncNode : public Node { /* \brief compiled target */ @@ -48,6 +56,8 @@ struct CachedFuncNode : public Node { tvm::Array outputs; /*! \brief The lowered functions to support the function. */ tvm::Array funcs; + /*! \brief Parameter usage states in the shape function. */ + tvm::Array shape_func_param_states; void VisitAttrs(tvm::AttrVisitor* v) final { v->Visit("target", &target); @@ -55,6 +65,7 @@ struct CachedFuncNode : public Node { v->Visit("inputs", &inputs); v->Visit("outputs", &outputs); v->Visit("funcs", &funcs); + v->Visit("shape_func_param_states", &shape_func_param_states); } static constexpr const char* _type_key = "relay.CachedFunc"; @@ -170,6 +181,12 @@ class CompileEngineNode : public Node { * \return The result. */ virtual PackedFunc JIT(const CCacheKey& key) = 0; + /*! + * \brief Lower the shape function. + * \param key The key to the cached function. + * \return The result. + */ + virtual CachedFunc LowerShapeFunc(const CCacheKey& key) = 0; /*! \brief clear the cache. */ virtual void Clear() = 0; @@ -180,7 +197,7 @@ class CompileEngineNode : public Node { TVM_DECLARE_NODE_TYPE_INFO(CompileEngineNode, Node); }; -/*! \brier cache entry used in compile engine */ +/*! \brief cache entry used in compile engine */ class CompileEngine : public NodeRef { public: CompileEngine() {} @@ -193,6 +210,13 @@ class CompileEngine : public NodeRef { TVM_DLL static const CompileEngine& Global(); }; +/*! + * \brief Check if the type is dynamic. + * \param ty The type to be checked. + * \return The result. + */ +bool IsDynamic(const Type& ty); + // implementations inline size_t CCacheKeyNode::Hash() const { if (hash_ != 0) return hash_; diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 913d7ad..dedff7a 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -280,7 +280,7 @@ class Interpreter : return TupleValueNode::make(values); } - // TODO(@jroesch): this doesn't support mututal letrec + // TODO(@jroesch): this doesn't support mutual letrec inline Value MakeClosure(const Function& func, Var letrec_name = Var()) { tvm::Map captured_mod; Array free_vars = FreeVars(func); @@ -310,7 +310,125 @@ class Interpreter : return MakeClosure(func); } - Value InvokePrimitiveOp(Function func, + Array ComputeDynamicShape(const Function& func, + const Array& args) { + auto key = CCacheKeyNode::make(func, Target::Create("llvm")); + auto cfunc = engine_->LowerShapeFunc(key); + size_t arity = cfunc->inputs.size() + cfunc->outputs.size(); + + std::vector values(arity); + std::vector codes(arity); + TVMArgsSetter setter(values.data(), codes.data()); + std::vector inputs(cfunc->inputs.size()); + std::vector outputs(cfunc->outputs.size()); + + DLContext cpu_ctx; + cpu_ctx.device_type = kDLCPU; + cpu_ctx.device_id = 0; + + auto fset_input = [&](size_t i, Value val, bool need_shape) { + const TensorValueNode* tv = val.as(); + CHECK(tv != nullptr) << "expect Tensor argument"; + if (need_shape) { + int64_t ndim = tv->data.Shape().size(); + NDArray shape_arr; + if (ndim == 0) { + shape_arr = NDArray::Empty({}, Type2TVMType(Int(64)), cpu_ctx); + } else { + shape_arr = NDArray::Empty({ndim}, Type2TVMType(Int(64)), cpu_ctx); + int64_t* data = reinterpret_cast(shape_arr->data); + for (auto j = 0; j < ndim; ++j) { + data[j] = tv->data.Shape()[j]; + } + } + inputs[i] = shape_arr; + setter(i, shape_arr); + } else { + auto arr = tv->data.CopyTo(cpu_ctx); + inputs[i] = arr; + setter(i, arr); + } + }; + + size_t arg_counter = 0; + for (size_t i = 0; i < args.size(); ++i) { + auto arg = args[i]; + auto param = func->params[i]; + int state = cfunc->shape_func_param_states[i]->value; + if (arg.as()) { + if (state & kNeedInputData) { + fset_input(arg_counter++, arg, false); + } + if (state & kNeedInputShape) { + fset_input(arg_counter++, arg, true); + } + } else { + const TupleValueNode* tuple = arg.as(); + CHECK(tuple != nullptr); + if (state & kNeedInputData) { + for (size_t i = 0; i < tuple->fields.size(); ++i) { + fset_input(arg_counter++, tuple->fields[i], false); + } + } + if (state & kNeedInputShape) { + for (size_t i = 0; i < tuple->fields.size(); ++i) { + fset_input(arg_counter++, tuple->fields[i], true); + } + } + } + } + CHECK_EQ(arg_counter, cfunc->inputs.size()) + << "Shape function input sizes mismatch"; + + auto fset_shape_output = [&](size_t i, Type val_type) { + // TODO(@icemelon): allow recursive tuple + const TensorTypeNode* rtype = val_type.as(); + CHECK(rtype != nullptr); + int64_t ndim = rtype->shape.size(); + auto arr = NDArray::Empty({ndim}, Type2TVMType(Int(64)), cpu_ctx); + outputs[i] = arr; + setter(arg_counter + i, arr); + }; + + auto ret_type = func->body->checked_type(); + size_t out_cnt = 0; + if (auto rtype = ret_type.as()) { + out_cnt = rtype->fields.size(); + for (size_t i = 0; i < out_cnt; ++i) { + fset_shape_output(i, rtype->fields[i]); + } + } else { + out_cnt = 1; + auto tt = Downcast(ret_type); + fset_shape_output(0, tt); + } + CHECK_EQ(cfunc->outputs.size(), out_cnt) + << "Shape function output sizes mismatch"; + + PackedFunc shape_func; + TVMRetValue rv; + if (const auto* f = runtime::Registry::Get("relay.backend.build")) { + tvm::runtime::Module m = (*f)(cfunc->funcs, cfunc->target); + shape_func = m.GetFunction(cfunc->func_name); + } else { + LOG(FATAL) << "relay.backend.build is not registered"; + } + shape_func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv); + + // Get output shapes + Array out_shapes; + for (auto out_tensor : outputs) { + int64_t* shape_data = reinterpret_cast(out_tensor->data); + Shape out_shape; + for (int i = 0; i < out_tensor->shape[0]; ++i) { + out_shape.push_back(tvm::Integer(shape_data[i])); + } + out_shapes.push_back(out_shape); + } + return out_shapes; + } + + Value InvokePrimitiveOp(const Function& func, const Array& args) { auto call_node = func->body.as(); @@ -394,17 +512,46 @@ class Interpreter : return out_tensor; }; + Array out_shapes; + auto ret_type = func->body->checked_type(); + bool is_dyn = IsDynamic(func->checked_type()); + if (call_node->op == Op::Get("shape_of")) { + // The output shape of shape_of must be static since Relay doesn't support + // dynamic rank tensors. + is_dyn = false; + } + + if (is_dyn) { + CHECK(func->IsPrimitive()); + out_shapes = ComputeDynamicShape(func, args); + } + PackedFunc packed_func = engine_->JIT(CCacheKeyNode::make(func, target_)); TVMRetValue rv; if (const TupleTypeNode* rtype = func->body->checked_type().as()) { + CHECK(!is_dyn || out_shapes.size() == rtype->fields.size()); Array fields; for (size_t i = 0; i < rtype->fields.size(); ++i) { - fields.push_back(fset_output(i, rtype->fields[i])); + if (is_dyn) { + auto sh = out_shapes[i]; + auto tt = Downcast(rtype->fields[i]); + fields.push_back(fset_output(i, TensorTypeNode::make(sh, tt->dtype))); + } else { + fields.push_back(fset_output(i, rtype->fields[i])); + } } packed_func.CallPacked(TVMArgs(values.data(), codes.data(), arg_len), &rv); return TupleValueNode::make(fields); } else { - Value out_tensor = fset_output(0, func->body->checked_type()); + Value out_tensor; + if (is_dyn) { + CHECK_EQ(out_shapes.size(), 1); + auto sh = out_shapes[0]; + auto tt = Downcast(ret_type); + out_tensor = fset_output(0, TensorTypeNode::make(sh, tt->dtype)); + } else { + out_tensor = fset_output(0, ret_type); + } packed_func.CallPacked(TVMArgs(values.data(), codes.data(), arg_len), &rv); return out_tensor; } diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 17de083..5e5bc1a 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -23,12 +23,15 @@ * \brief A compiler from relay::Module to the VM byte code. */ +#include #include #include #include #include #include #include +#include +#include #include #include #include @@ -61,42 +64,6 @@ using namespace relay::transform; // (@jroesch): VM passes, eventually declare as passes. bool IsClosure(const Function& func); -// Compute the constant pool, i.e a mapping from Constant node to constant index. -struct ConstantPool : ExprVisitor { - std::set visited; - Module module; - ConstMap const_map; - ConstTensorShapeMap const_tensor_shape_map; - - size_t index; - - explicit ConstantPool(const Module& mod) : module(mod), const_map(), index(0) {} - - void VisitExpr_(const GlobalVarNode* var_node) { - auto gvar = GetRef(var_node); - if (visited.find(gvar) == visited.end()) { - visited.insert(gvar); - this->VisitExpr(this->module->Lookup(gvar)); - } - } - - void VisitExpr_(const ConstantNode* const_node) { - auto konst = GetRef(const_node); - auto it = this->const_map.find(konst); - if (it == this->const_map.end()) { - this->const_map.insert({konst, index++}); - } - } -}; - -std::tuple LayoutConstantPool(const Module& module) { - auto cp = ConstantPool(module); - for (auto& func : module->functions) { - cp.VisitExpr(func.first); - } - return std::make_tuple(cp.const_map, cp.const_tensor_shape_map); -} - void InstructionPrint(std::ostream& os, const Instruction& instr); // Represent a runtime object that's going to be matched by pattern match expressions @@ -220,12 +187,13 @@ TreeNodePtr BuildDecisionTreeFromClauses(MatchValuePtr data, tvm::Array class VMFunctionCompiler : ExprFunctor { public: - VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets) + VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets, Target target_host) : last_register_(0), registers_num_(0), engine_(CompileEngine::Global()), context_(context), - targets_(targets) {} + targets_(targets), + target_host_(target_host) {} VMFunction Compile(const GlobalVar& var, const Function& func) { size_t i = 0; @@ -288,10 +256,9 @@ class VMFunctionCompiler : ExprFunctor { } void VisitExpr_(const ConstantNode* const_node) { - auto rconst = GetRef(const_node); - auto it = this->context_->const_map.find(rconst); - CHECK(it != this->context_->const_map.end()); - Emit(Instruction::LoadConst(it->second, NewRegister())); + size_t konst_idx = context_->constants.size(); + context_->constants.push_back(const_node->data); + Emit(Instruction::LoadConst(konst_idx, NewRegister())); } void VisitExpr_(const VarNode* var_node) { @@ -326,7 +293,7 @@ class VMFunctionCompiler : ExprFunctor { } void VisitExpr_(const LetNode* let_node) { - DLOG(INFO) << let_node->value; + DLOG(INFO) << AsText(let_node->value); this->VisitExpr(let_node->value); var_register_map_.insert({let_node->var, this->last_register_}); this->VisitExpr(let_node->body); @@ -393,29 +360,206 @@ class VMFunctionCompiler : ExprFunctor { this->last_register_ = true_register; } - Instruction AllocTensorFromType(const TensorTypeNode* ttype) { - TVMType dltype = Type2TVMType(ttype->dtype); - auto tensor_type = GetRef(ttype); + Index EmitGetShape(const TensorTypeNode* ttype, Index reg) { + bool const_shape = true; std::vector shape; - for (auto dim : tensor_type->shape) { - shape.push_back(Downcast(dim)->value); + for (auto dim : ttype->shape) { + if (auto kdim = dim.as()) { + shape.push_back(kdim->value); + } else { + const_shape = false; + } + } + if (const_shape) { + int64_t ndim = shape.size(); + DLContext cpu_ctx; + cpu_ctx.device_type = kDLCPU; + cpu_ctx.device_id = 0; + NDArray shape_tensor; + if (ndim == 0) { + shape_tensor = NDArray::Empty({}, Type2TVMType(Int(64)), cpu_ctx); + } else { + shape_tensor = NDArray::Empty({ndim}, Type2TVMType(Int(64)), cpu_ctx); + int64_t* dims = reinterpret_cast(shape_tensor->data); + for (size_t i = 0; i < shape.size(); ++i) { + dims[i] = shape[i]; + } + } + size_t konst_idx = context_->constants.size(); + context_->constants.push_back(shape_tensor); + Emit(Instruction::LoadConst(konst_idx, NewRegister())); + return last_register_; + } + // For dynamic shape, we need insert shape_of op to get its shape at runtime + auto attrs = make_node(); + attrs->dtype = Int(64); + static const Op& op = Op::Get("shape_of"); + auto input = VarNode::make("input", GetRef(ttype)); + auto expr = CallNode::make(op, {input}, Attrs(attrs), {}); + auto func = FunctionNode::make({input}, expr, IncompleteTypeNode::make(Kind::kType), {}); + auto mod = ModuleNode::make({}, {}); + auto main_gv = GlobalVarNode::make("main"); + mod->Add(main_gv, func); + func = mod->Lookup(main_gv); + + // shape_of op has to be run on the host target + // TODO(@icemelon9): handle heterogeneous target, such as cuda + auto key = CCacheKeyNode::make(func, target_host_); + auto cfunc = engine_->Lower(key); + auto op_index = -1; + if (context_->seen_funcs.find(cfunc->funcs[0]) == context_->seen_funcs.end()) { + op_index = context_->cached_funcs.size(); + context_->cached_funcs.push_back(cfunc); + context_->seen_funcs[cfunc->funcs[0]] = op_index; + } else { + op_index = context_->seen_funcs[cfunc->funcs[0]]; + } + std::vector arg_regs{reg}; + int64_t ndim = ttype->shape.size(); + if (ndim == 0) { + Emit(Instruction::AllocTensor({}, Int(64), NewRegister())); + } else { + Emit(Instruction::AllocTensor({ndim}, Int(64), NewRegister())); + } + Index shape_reg = last_register_; + arg_regs.push_back(shape_reg); + Emit(Instruction::InvokePacked(op_index, 2, 1, arg_regs)); + return shape_reg; + } + + std::vector EmitShapeFunc(const Type& ret_type, const Function& func, + const std::vector& unpacked_arg_regs) { + // Find the mapping from params to registers + int idx = 0; + std::vector> param_regs; + std::vector> param_types; + for (auto param : func->params) { + auto ty = param->checked_type(); + std::vector regs; + std::vector types; + if (auto ttype = ty.as()) { + regs.push_back(unpacked_arg_regs[idx++]); + types.push_back(ttype); + } else if (const auto tuple_ty = ret_type.as()) { + for (size_t j = 0; j < tuple_ty->fields.size(); ++j, ++idx) { + regs.push_back(unpacked_arg_regs[idx]); + auto ttype = tuple_ty->fields[j].as(); + CHECK(ttype); + types.push_back(ttype); + } + } else { + LOG(FATAL) << "unsupported parameter type " << ty; + } + param_regs.push_back(regs); + param_types.push_back(types); + } + + // Lower shape function + auto key = CCacheKeyNode::make(func, target_host_); + auto cfunc = engine_->LowerShapeFunc(key); + int op_index = -1; + if (context_->seen_funcs.count(cfunc->funcs[0]) == 0) { + op_index = context_->cached_funcs.size(); + context_->cached_funcs.push_back(cfunc); + context_->seen_funcs[cfunc->funcs[0]] = op_index; + } else { + op_index = context_->seen_funcs[cfunc->funcs[0]]; + } + + // Prepare input and output registers + std::vector shape_func_args; + std::vector shape_regs; + for (size_t i = 0; i < func->params.size(); ++i) { + int state = cfunc->shape_func_param_states[i]->value; + if (state & kNeedInputData) { + for (auto reg : param_regs[i]) { + // TODO(@icemelon9): Need to copy data here for heterogeneous exec + shape_func_args.push_back(reg); + } + } + if (state & kNeedInputShape) { + for (size_t j = 0; j < param_regs[i].size(); ++j) { + shape_func_args.push_back(EmitGetShape(param_types[i][j], param_regs[i][j])); + } + } + } + for (auto t : cfunc->outputs) { + int64_t ndim = t->shape[0].as()->value; + Emit(Instruction::AllocTensor({ndim}, t->dtype, NewRegister())); + shape_func_args.push_back(last_register_); + shape_regs.push_back(last_register_); + } + + int arity = shape_func_args.size(); + int ret_count = shape_regs.size(); + Emit(Instruction::InvokePacked(op_index, arity, ret_count, shape_func_args)); + + // Alloc return tensors given the shape regs + std::vector ret_dtypes; + if (const auto* tuple_type = ret_type.as()) { + for (auto field : tuple_type->fields) { + const TensorTypeNode* tty = field.as(); + CHECK(tty); + ret_dtypes.push_back(tty->dtype); + } + } else { + auto tty = ret_type.as(); + CHECK(tty); + ret_dtypes.push_back(tty->dtype); + } + std::vector ret_regs; + for (size_t i = 0; i < shape_regs.size(); ++i) { + Emit(Instruction::AllocTensorReg(shape_regs[i], ret_dtypes[i], NewRegister())); + ret_regs.push_back(last_register_); + } + return ret_regs; + } + + std::vector AllocReturnType(const Type& ret_type, const Function& func, + const std::vector& unpacked_arg_regs) { + auto op = func->body.as()->op; + // 1. If either func param types or ret type is dynamic, we need to insert + // shape func to perform type checking at runtime. + // 2. We skip the shape_of function since currently Relay doesn't support + // dynamic rank tensor. + if (op != Op::Get("shape_of") && IsDynamic(func->checked_type())) { + return EmitShapeFunc(ret_type, func, unpacked_arg_regs); } - return Instruction::AllocTensor(shape, dltype, NewRegister()); + std::vector ret_regs; + auto alloc_tensor = [&](const TensorTypeNode* ttype) { + const TensorType& tensor_type = GetRef(ttype); + std::vector shape; + for (auto dim : tensor_type->shape) { + shape.push_back(Downcast(dim)->value); + } + Emit(Instruction::AllocTensor(shape, Type2TVMType(tensor_type->dtype), NewRegister())); + ret_regs.push_back(last_register_); + }; + if (const TensorTypeNode* ttype = ret_type.as()) { + alloc_tensor(ttype); + } else if (const TupleTypeNode* ttype = ret_type.as()) { + for (auto field : ttype->fields) { + alloc_tensor(field.as()); + } + } else { + LOG(FATAL) << "Unsupported return value type"; + } + return ret_regs; } void EmitInvokePrimitive(const Function& func, - const std::vector& args_registers, + const std::vector& arg_registers, const Type& ret_type) { std::vector unpacked_arg_regs; std::vector allocs; // Arity calculation must flatten tuples. size_t arity = 0; - CHECK_EQ(func->params.size(), args_registers.size()); + CHECK_EQ(func->params.size(), arg_registers.size()); for (size_t i = 0; i < func->params.size(); i++) { auto ty = func->params[i]->checked_type(); if (ty.as()) { - unpacked_arg_regs.push_back(args_registers[i]); + unpacked_arg_regs.push_back(arg_registers[i]); arity += 1; } else if (auto tuple_ty = ty.as()) { for (size_t f = 0; f < tuple_ty->fields.size(); f++) { @@ -424,7 +568,7 @@ class VMFunctionCompiler : ExprFunctor { << "only supports non-nested tuples currently " << "found " << field; auto dst = NewRegister(); - Emit(Instruction::GetField(args_registers[i], f, dst)); + Emit(Instruction::GetField(arg_registers[i], f, dst)); unpacked_arg_regs.push_back(dst); } arity += tuple_ty->fields.size(); @@ -433,30 +577,11 @@ class VMFunctionCompiler : ExprFunctor { } } - size_t return_val_count = 0; - if (const TensorTypeNode* ttype = ret_type.as()) { - // Allocate space for the return tensor. - auto alloc = AllocTensorFromType(ttype); - allocs.push_back(alloc); - return_val_count = 1; - } else if (const TupleTypeNode* ttype = ret_type.as()) { - std::vector fields_registers; - - for (size_t i = 0; i < ttype->fields.size(); ++i) { - auto f = ttype->fields[i]; - auto f_type = f.as(); - allocs.push_back(AllocTensorFromType(f_type)); - fields_registers.push_back(allocs.back().dst); - } - return_val_count = ttype->fields.size(); - } else { - LOG(FATAL) << "Unsupported return value type"; - } - - arity += return_val_count; - for (auto& alloc : allocs) { - Emit(alloc); - unpacked_arg_regs.push_back(alloc.dst); + auto ret_regs = AllocReturnType(ret_type, func, unpacked_arg_regs); + size_t return_count = ret_regs.size(); + arity += return_count; + for (auto reg : ret_regs) { + unpacked_arg_regs.push_back(reg); } // Next generate the invoke instruction. @@ -477,22 +602,22 @@ class VMFunctionCompiler : ExprFunctor { CHECK_EQ(cfunc->funcs.size(), 1); auto op_index = -1; if (context_->seen_funcs.find(cfunc->funcs[0]) == context_->seen_funcs.end()) { - op_index = context_->lowered_funcs.size(); - context_->lowered_funcs.push_back(cfunc->funcs[0]); + op_index = context_->cached_funcs.size(); + context_->cached_funcs.push_back(cfunc); context_->seen_funcs[cfunc->funcs[0]] = op_index; } else { op_index = context_->seen_funcs[cfunc->funcs[0]]; } - Emit(Instruction::InvokePacked(op_index, arity, return_val_count, unpacked_arg_regs)); + Emit(Instruction::InvokePacked(op_index, arity, return_count, unpacked_arg_regs)); - if (return_val_count > 1) { + if (return_count > 1) { // return value is a tuple, we need to create a tuple std::vector fields_registers; - for (size_t i = arity - return_val_count; i < arity; ++i) { + for (size_t i = arity - return_count; i < arity; ++i) { fields_registers.push_back(unpacked_arg_regs[i]); } - Emit(Instruction::AllocDatatype(0, return_val_count, fields_registers, NewRegister())); + Emit(Instruction::AllocDatatype(0, return_count, fields_registers, NewRegister())); } } @@ -636,6 +761,8 @@ class VMFunctionCompiler : ExprFunctor { VMCompilerContext* context_; /*! \brief Target devices. */ TargetsMap targets_; + /*! \brief Host target. */ + Target target_host_; }; @@ -676,28 +803,18 @@ void VMCompiler::Compile(const Module& mod_ref, // in the VMFunction table. PopulateGlobalMap(); - // Next we populate constant map. - auto constant_analysis_result = LayoutConstantPool(context_.module); - context_.const_map = std::get<0>(constant_analysis_result); - context_.const_tensor_shape_map = std::get<1>(constant_analysis_result); - // Next we get ready by allocating space for // the global state. vm_->functions.resize(context_.module->functions.size()); - vm_->constants.resize(context_.const_map.size() + context_.const_tensor_shape_map.size()); - for (auto pair : context_.const_map) { - vm_->constants[pair.second] = Object::Tensor(pair.first->data); - } - - for (auto pair : context_.const_tensor_shape_map) { - vm_->constants[pair.second.first] = Object::Tensor(pair.second.second); - } + // Next we get ready by allocating space for + // the global state. + vm_->functions.resize(context_.module->functions.size()); for (auto named_func : context_.module->functions) { auto gvar = named_func.first; auto func = named_func.second; - VMFunctionCompiler func_compiler(&context_, targets_); + VMFunctionCompiler func_compiler(&context_, targets_, target_host_); auto vm_func = func_compiler.Compile(gvar, func); size_t func_index = context_.global_map.at(gvar); @@ -711,6 +828,11 @@ void VMCompiler::Compile(const Module& mod_ref, } #endif // USE_RELAY_DEBUG + // populate constants + for (auto data : context_.constants) { + vm_->constants.push_back(Object::Tensor(data)); + } + LibraryCodegen(); for (auto gv : context_.global_map) { @@ -721,11 +843,13 @@ void VMCompiler::Compile(const Module& mod_ref, Module VMCompiler::OptimizeModule(const Module& mod) { // TODO(@icemelon9): check number of targets and build config, add more optimization pass transform::Sequential seq({transform::SimplifyInference(), - transform::ToANormalForm(), transform::InlinePrimitives(), + // TODO(@wweic): FuseOps pass currently don't handle Let + // For now, we put FuseOps before ToANormalForm to enable it + transform::FuseOps(), + transform::ToANormalForm(), transform::LambdaLift(), - transform::InlinePrimitives(), - transform::FuseOps()}); + transform::InlinePrimitives()}); auto pass_ctx = transform::PassContext::Create(); tvm::With ctx(pass_ctx); return seq(mod); @@ -741,27 +865,36 @@ void VMCompiler::PopulateGlobalMap() { } void VMCompiler::LibraryCodegen() { - auto const& lowered_funcs = context_.lowered_funcs; - if (lowered_funcs.size() == 0) { + auto const &cached_funcs = context_.cached_funcs; + if (cached_funcs.size() == 0) { return; } - // TODO(@icemelon9): support heterogeneous targets - Target target; - for (auto kv : targets_) { - target = kv.second; + std::unordered_map> tgt_funcs; + for (auto &cfunc : cached_funcs) { + std::string target_str = cfunc->target->str(); + if (tgt_funcs.count(target_str) == 0) { + tgt_funcs.emplace(target_str, Array{cfunc->funcs[0]}); + } else { + tgt_funcs[target_str].push_back(cfunc->funcs[0]); + } } - if (const auto* f = runtime::Registry::Get("relay.backend.build")) { - runtime::Module mod = - (*f)(tvm::Array(lowered_funcs.begin(), lowered_funcs.end()), target, - target_host_); + Map> funcs; + for (auto &it : tgt_funcs) { + funcs.Set(Target::Create(it.first), it.second); + } + + if (const auto *f = runtime::Registry::Get("relay.backend.build")) { + // The target is just a dummy arg because funcs already contains corresponding target + // therefore target won't be used in the build function + runtime::Module mod = (*f)(funcs, Target(), target_host_); CHECK(mod.operator->()); vm_->lib = mod; } else { LOG(FATAL) << "relay.backend.build is not registered"; } size_t primitive_index = 0; - for (auto lfunc : lowered_funcs) { - vm_->primitive_map.insert({lfunc->name, primitive_index++}); + for (auto cfunc : cached_funcs) { + vm_->primitive_map.insert({cfunc->funcs[0]->name, primitive_index++}); } } diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index 4a2de0a..bfe19ac 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -72,12 +72,10 @@ struct VMCompilerContext { TagMap tag_map; // Map from global var to a unique integer GlobalMap global_map; - // Map from Const object to its index in const pool - ConstMap const_map; - // Map from Const tensor shape to its index in const pool - ConstTensorShapeMap const_tensor_shape_map; - // List of lowered functions - std::vector lowered_funcs; + // List of constants + std::vector constants; + // List of cached functions + std::vector cached_funcs; // The functions that have been lowered. std::unordered_map seen_funcs; }; diff --git a/src/relay/ir/op.cc b/src/relay/ir/op.cc index b4303e7..76b56ae 100644 --- a/src/relay/ir/op.cc +++ b/src/relay/ir/op.cc @@ -121,14 +121,25 @@ TVM_REGISTER_API("relay.op._ListOpNames") TVM_REGISTER_API("relay.op._GetOp").set_body_typed(Op::Get); TVM_REGISTER_API("relay.op._OpGetAttr") - .set_body([](TVMArgs args, TVMRetValue* rv) { - Op op = args[0]; - std::string attr_name = args[1]; - auto op_map = Op::GetAttr(attr_name); - if (op_map.count(op)) { - *rv = op_map[op]; - } - }); +.set_body([](TVMArgs args, TVMRetValue* rv) { + Op op = args[0]; + std::string attr_name = args[1]; + auto op_map = Op::GetAttr(attr_name); + if (op_map.count(op)) { + *rv = op_map[op]; + } + }); + +TVM_REGISTER_API("relay.op._OpSetAttr") +.set_body([](TVMArgs args, TVMRetValue* rv) { + Op op = args[0]; + std::string attr_name = args[1]; + runtime::TVMArgValue value = args[2]; + int plevel = args[3]; + auto& reg = + OpRegistry::Registry()->__REGISTER_OR_GET__(op->name).set_name(); + reg.set_attr(attr_name, value, plevel); + }); TVM_REGISTER_API("relay.op._Register") .set_body([](TVMArgs args, TVMRetValue* rv) { diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index c3975c3..6305d22 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -528,7 +528,11 @@ bool ReshapeRel(const Array& types, used_input_dims.insert(src_idx); IndexExpr d2 = data_shape[src_idx++]; used_output_dims.insert(oshape.size()); - oshape.push_back(d1 * d2); + if (d1.as() || d2.as()) { + oshape.push_back(Any::make()); + } else { + oshape.push_back(d1 * d2); + } } else if (svalue == -4) { // split the source dim s into two dims // read the left dim and then the right dim (either can be -1) @@ -563,6 +567,8 @@ bool ReshapeRel(const Array& types, oshape.push_back(d2); } } + } else { + CHECK(false) << "Unsupported special value: " << svalue; } } @@ -608,7 +614,15 @@ Array ReshapeCompute(const Attrs& attrs, const Target& target) { const auto* out_ttype = out_type.as(); CHECK(out_ttype != nullptr); - return { topi::reshape(inputs[0], out_ttype->shape) }; + Array newshape; + for (auto val : out_ttype->shape) { + if (val->is_type()) { + newshape.push_back(val.as()->ToVar()); + } else { + newshape.push_back(val); + } + } + return { topi::reshape(inputs[0], newshape) }; } Expr MakeReshape(Expr data, @@ -1108,7 +1122,8 @@ RELAY_REGISTER_OP("arange") .set_support_level(3) .add_type_rel("Arange", ArangeRel) .set_attr("FTVMCompute", ArangeCompute) -.set_attr("TOpPattern", kInjective) +// TODO(@icemelon): Change arange to kOpaque because FuseOps doesn't consider dynamic shape +.set_attr("TOpPattern", kOpaque) .set_attr("AnyCodegenStrategy", kVariableDimensions); // repeat operator diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc index d3afb91..826fe69 100644 --- a/src/relay/op/tensor/unary.cc +++ b/src/relay/op/tensor/unary.cc @@ -295,7 +295,9 @@ RELAY_REGISTER_OP("shape_of") .add_argument("data", "Tensor", "The input tensor.") .add_type_rel("ShapeOf", ShapeOfRel) .set_attr("TOpIsStateful", false) -.set_attr("TOpPattern", kInjective) +// Use kOpaque for shape_of op for now since it won't be performance critic, +// and it makes things easier for dynamic shape func +.set_attr("TOpPattern", kOpaque) .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) .set_support_level(10) diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc index d4efe80..f71b85d 100644 --- a/src/relay/op/type_relations.cc +++ b/src/relay/op/type_relations.cc @@ -81,16 +81,17 @@ Type ConcreteBroadcast(const TensorType& t1, for (; i <= std::min(ndim1, ndim2); ++i) { IndexExpr s1 = t1->shape[ndim1 - i]; IndexExpr s2 = t2->shape[ndim2 - i]; - if (EqualCheck(s1, s2)) { - oshape.push_back(s1); - } else if (EqualConstInt(s1, 1)) { + if (EqualConstInt(s1, 1)) { oshape.push_back(s2); } else if (EqualConstInt(s2, 1)) { oshape.push_back(s1); - } else if (s1.as() && EqualConstInt(s2, 1)) { - // TODO(@jroesch): we need to come back to this + } else if (s1.as()) { + // s1 == 1 || s1 == s2 oshape.push_back(s2); - } else if (s2.as() && EqualConstInt(s1, 1)) { + } else if (s2.as()) { + // s2 == 1 || s2 == s1 + oshape.push_back(s1); + } else if (EqualCheck(s1, s2)) { oshape.push_back(s1); } else { RELAY_ERROR( diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index cdd2837..9dc180f 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -915,7 +915,7 @@ class FuseMutator : private ExprMutator { if (it == gmap_.end()) return ""; std::ostringstream os; auto *group = it->second->FindRoot(); - os << "group=" << group; + os << " /* group=" << group << " */"; return os.str(); }); LOG(INFO) << "Dump of group info:\n" << text; diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index cef8e72..153b90c 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -120,7 +120,7 @@ class ModulePassNode : public PassNode { /*! * \brief Get the pass information/meta data. */ - PassInfo Info() const { return pass_info; } + PassInfo Info() const override { return pass_info; } TVM_DLL static ModulePass make( runtime::TypedPackedFunc pass_func, @@ -174,7 +174,7 @@ class FunctionPassNode : public PassNode { /*! * \brief Get the pass information/meta data. */ - PassInfo Info() const { return pass_info; } + PassInfo Info() const override { return pass_info; } TVM_DLL static FunctionPass make( runtime::TypedPackedFunc pass_func, @@ -220,7 +220,7 @@ class SequentialNode : public PassNode { /*! * \brief Get the pass information/meta data. */ - PassInfo Info() const { return pass_info; } + PassInfo Info() const override { return pass_info; } /*! * \brief Check if a pass is enabled. diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 33990ae..02ea3a4 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -451,11 +451,11 @@ std::string StrJoin(T* items, int offset, int cnt, std::string delim = ", ") { void InstructionPrint(std::ostream& os, const Instruction& instr) { switch (instr.op) { case Opcode::Move: { - os << "move $" << instr.dst << " $" << instr.from << std::endl; + os << "move $" << instr.dst << " $" << instr.from; break; } case Opcode::Ret: { - os << "ret $" << instr.result << std::endl; + os << "ret $" << instr.result; break; } case Opcode::Fatal: { @@ -469,7 +469,7 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { << ", out: $" << StrJoin(instr.packed_args, instr.arity - instr.output_size, instr.output_size, ", $") - << ")" << std::endl; + << ")"; break; } case Opcode::AllocTensor: { @@ -478,71 +478,61 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { instr.alloc_tensor.ndim) << "] "; DLDatatypePrint(os, instr.alloc_tensor.dtype); - os << std::endl; break; } case Opcode::AllocTensorReg: { os << "alloc_tensor_reg $" << instr.dst << " $" << instr.alloc_tensor_reg.shape_register << " "; DLDatatypePrint(os, instr.alloc_tensor_reg.dtype); - os << std::endl; break; } case Opcode::AllocDatatype: { os << "alloc_data $" << instr.dst << " tag(" << instr.constructor_tag << ") [$" - << StrJoin(instr.datatype_fields, 0, instr.num_fields, ",$") << "]" - << std::endl; + << StrJoin(instr.datatype_fields, 0, instr.num_fields, ",$") << "]"; break; } case Opcode::AllocClosure: { os << "alloc_closure $" << instr.dst << " VMFunc[" << instr.clo_index << "]($" << StrJoin(instr.free_vars, 0, instr.num_freevar, ",$") - << ")" - << std::endl; + << ")"; break; } case Opcode::If: { os << "if " << "$" << instr.if_op.test << " " << instr.if_op.target << " " - << instr.if_op.true_offset << " " << instr.if_op.false_offset - << std::endl; + << instr.if_op.true_offset << " " << instr.if_op.false_offset; break; } case Opcode::Invoke: { os << "invoke $" << instr.dst << " VMFunc[" << instr.func_index << "]($" << StrJoin(instr.invoke_args_registers, 0, instr.num_args, ",$") - << ")" - << std::endl; + << ")"; break; } case Opcode::InvokeClosure: { os << "invoke_closure $" << instr.dst << " $" << instr.closure << "($" << StrJoin(instr.closure_args, 0, instr.num_closure_args, ",$") - << ")" - << std::endl; + << ")"; break; } case Opcode::LoadConst: { - os << "load_const $" << instr.dst << " Const[" << instr.const_index << "]" - << std::endl; + os << "load_const $" << instr.dst << " Const[" << instr.const_index << "]"; break; } case Opcode::LoadConsti: { - os << "load_consti $" << instr.dst << " Const[" << instr.load_consti.val << "]" - << std::endl; + os << "load_consti $" << instr.dst << " Const[" << instr.load_consti.val << "]"; break; } case Opcode::GetField: { os << "get_field $" << instr.dst << " $" << instr.object << "[" - << instr.field_index << "]" - << std::endl; + << instr.field_index << "]"; break; } case Opcode::GetTag: { - os << "get_tag $" << instr.dst << " $" << instr.get_tag.object << std::endl; + os << "get_tag $" << instr.dst << " $" << instr.get_tag.object; break; } case Opcode::Goto: { - os << "goto " << instr.pc_offset << std::endl; + os << "goto " << instr.pc_offset; break; } default: @@ -559,9 +549,7 @@ std::ostream& operator<<(std::ostream& os, const Instruction& instr) { void VMFunctionPrint(std::ostream& os, const VMFunction& vm_func) { os << vm_func.name << ": " << std::endl; for (size_t i = 0; i < vm_func.instructions.size(); ++i) { - os << i << ": "; - InstructionPrint(os, vm_func.instructions[i]); - os << ";" << std::endl; + os << i << ": " << vm_func.instructions[i] << ";" << std::endl; } } @@ -801,7 +789,7 @@ void VirtualMachine::RunLoop() { while (true) { main_loop: auto const& instr = this->code[this->pc]; - DLOG(INFO) << "Executing(" << pc << "): "; + DLOG(INFO) << "Executing(" << pc << "): " << instr; #if USE_RELAY_DEBUG InstructionPrint(std::cout, instr); #endif // USE_RELAY_DEBUG diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc index 760ed0f..31f6169 100644 --- a/src/schedule/schedule_dataflow_rewrite.cc +++ b/src/schedule/schedule_dataflow_rewrite.cc @@ -546,6 +546,8 @@ void InjectInline(ScheduleNode* sch) { std::vector > new_body(sch->stages.size()); std::vector changed(sch->stages.size(), false); + std::vector new_hybrid_body(sch->stages.size()); + std::vector hybrid_changed(sch->stages.size(), false); // inline all the ops for (size_t i = sch->stages.size(); i != 0; --i) { Stage stage = sch->stages[i - 1]; @@ -568,6 +570,7 @@ void InjectInline(ScheduleNode* sch) { for (size_t j = i; j < sch->stages.size(); ++j) { Stage s = sch->stages[j]; const ComputeOpNode* compute = s->op.as(); + const HybridOpNode* hybrid = s->op.as(); if (compute) { if (!new_body[j].size()) { new_body[j] = compute->body; @@ -606,6 +609,15 @@ void InjectInline(ScheduleNode* sch) { } } } + } else if (hybrid) { + if (!new_hybrid_body[j].defined()) { + new_hybrid_body[j] = hybrid->body; + } + Stmt new_stmt = ir::Inline(new_hybrid_body[j], stage->op, args, body); + if (!new_stmt.same_as(new_hybrid_body[j])) { + new_hybrid_body[j] = new_stmt; + hybrid_changed[j] = true; + } } } } @@ -632,6 +644,17 @@ void InjectInline(ScheduleNode* sch) { } s->op = op; } + } else if (hybrid_changed[i]) { + const HybridOpNode* hybrid = sch->stages[i]->op.as(); + CHECK(hybrid); + Operation op = HybridOpNode::make( + hybrid->name, hybrid->tag, hybrid->attrs, hybrid->inputs, + hybrid->outputs, new_hybrid_body[i]); + op = op->ReplaceInputs(op, repl); + for (int idx = 0; idx < s->op->num_outputs(); ++idx) { + repl[s->op.output(idx)] = op.output(idx); + } + s->op = op; } else { Operation op = s->op->ReplaceInputs(s->op, repl); if (!op.same_as(s->op)) { diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index f0a3fe7..214b88f 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -18,27 +18,156 @@ import numpy as np import tvm from tvm import relay -from tvm.relay import Kind, transform from tvm.relay.loops import while_loop from tvm.relay.testing import run_infer_type as infer_type def int32(val): return relay.const(val, 'int32') +def any_dims(ndim): + shape = [] + for _ in range(ndim): + shape.append(relay.Any()) + return tuple(shape) + +# TODO(@wweic): because vm doesn't support heterogeneous exec, we can only test +# shape function on CPU. + +def verify_any_broadcast(x_shape, y_shape, x_np_shape, y_np_shape, op, np_op): + dtype = 'float32' + x = relay.var('x', shape=x_shape, dtype=dtype) + y = relay.var('y', shape=y_shape, dtype=dtype) + mod = relay.module.Module() + mod["main"] = relay.Function([x, y], op(x, y)) + x_np = np.random.uniform(size=x_np_shape).astype(dtype) + y_np = np.random.uniform(size=y_np_shape).astype(dtype) + for kind in ["debug", "vm"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(x_np, y_np) + tvm.testing.assert_allclose(result.asnumpy(), np_op(x_np, y_np)) + +def test_any_broadcast(): + verify_any_broadcast((relay.Any(),), (3, 2), (1,), (3, 2), relay.add, np.add) + verify_any_broadcast((relay.Any(), 2), (1, 2), (1, 2), (1, 2), relay.add, np.add) + verify_any_broadcast((relay.Any(), 2), (1, 2), (3, 2), (1, 2), relay.add, np.add) + verify_any_broadcast((relay.Any(), 2), (3, 2), (1, 2), (3, 2), relay.add, np.add) + verify_any_broadcast((relay.Any(), 2), (3, relay.Any()), (1, 2), (3, 1), relay.add, np.add) + + # The following currently fail because topi compute treats Any as 1 + # will requires auto_broadcast buffer to solve the problem + # TODO(@zhiics): Fix this + # verify_any_broadcast((relay.Any(),), (3, 2), (2,), (3, 2), relay.add, np.add) + # verify_any_broadcast((relay.Any(), 2), (3, 2), (3, 2), (3, 2), relay.add, np.add) + +def test_any_concat(): + x = relay.var('x', shape=(relay.Any(), 2), dtype="float32") + y = relay.var('y', shape=(1, 2), dtype="float32") + z = relay.op.concatenate([x, y], axis=0) + mod = relay.module.Module() + mod["main"] = relay.Function([x, y], z) + x_np = np.random.uniform(size=(3, 2)).astype('float32') + y_np = np.random.uniform(size=(1, 2)).astype('float32') + for kind in ["debug", "vm"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(x_np, y_np) + ref = np.concatenate([x_np, y_np], axis=0) + tvm.testing.assert_allclose(result.asnumpy(), ref) + +def verify_any_reshape(x_shape, newshape, x_np_shape, out_shape): + x = relay.var('x', shape=x_shape, dtype="float32") + y = relay.reshape(x, newshape=newshape) + mod = relay.module.Module() + mod["main"] = relay.Function([x], y) + data = np.random.uniform(size=x_np_shape).astype('float32') + for kind in ["debug", "vm"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(data).asnumpy() + assert result.shape == out_shape + tvm.testing.assert_allclose(result.flatten(), data.flatten()) + +def test_any_reshape(): + verify_any_reshape(any_dims(3), (1, -1), (2, 3, 4), (1, 24)) + verify_any_reshape(any_dims(3), (0, -1), (2, 3, 4), (2, 12)) + verify_any_reshape(any_dims(3), (0, -2), (2, 3, 4), (2, 3, 4)) + verify_any_reshape(any_dims(3), (-4, 2, -1, -2), (6, 3, 4), (2, 3, 3, 4)) + verify_any_reshape(any_dims(3), (-4, -1, 2, -3), (6, 3, 4), (3, 2, 12)) + +def verify_any_take(data_shape, indices_shape, axis, data_np_shape, indices_np_shape): + mod = relay.Module() + data = relay.var('data', shape=data_shape, dtype='float32') + indices = relay.var('indices', shape=indices_shape, dtype='int32') + y = relay.take(data, indices, axis=axis) + mod["main"] = relay.Function([data, indices], y) + data_np = np.random.uniform(size=data_np_shape).astype('float32') + if axis is None: + max_index = data_np.size + else: + max_index = data_np.shape[axis] + indices_np = np.random.randint(max_index, size=indices_np_shape).astype('int32') + ref = np.take(data_np, indices_np, axis=axis) + for kind in ["debug", "vm"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(data_np, indices_np) + tvm.testing.assert_allclose(result.asnumpy(), ref) + +def test_any_take(): + verify_any_take(any_dims(2), (1,), 0, (4, 5), (1,)) + verify_any_take(any_dims(2), (), 0, (4, 5), ()) + verify_any_take(any_dims(2), (), None, (4, 5), ()) + verify_any_take(any_dims(3), any_dims(2), 1, (3, 4, 5), (2, 3)) + verify_any_take(any_dims(2), any_dims(3), None, (4, 5), (2, 3, 4)) + verify_any_take(any_dims(2), any_dims(4), -1, (4, 5), (2, 3, 4, 5)) + +def test_any_shape_of(): + x = relay.var('x', shape=any_dims(2), dtype='float32') + y = relay.shape_of(x) + mod = relay.module.Module() + mod["main"] = relay.Function([x], y) + data = np.random.uniform(size=(3, 4)).astype('float32') + for kind in ["debug", "vm"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(data) + tvm.testing.assert_allclose(result.asnumpy(), np.array([3,4]).astype("int64")) + + x = relay.var('x', shape=any_dims(3), dtype='float32') + y0 = relay.shape_of(x) + y1 = relay.take(y0, relay.const(1, 'int32')) + mod = relay.module.Module() + mod["main"] = relay.Function([x], y1) + data = np.random.uniform(size=(2, 3, 4)).astype('float32') + for kind in ["debug", "vm"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(data) + tvm.testing.assert_allclose(result.asnumpy(), np.array(3).astype("int64")) + +def test_fused_ops(): + x = relay.var('x', shape=(relay.Any(), relay.Any()), dtype='float32') + y0 = x + relay.const(1.0, 'float32') + y1 = y0 * relay.const(2.0, 'float32') + mod = relay.module.Module() + mod["main"] = relay.Function([x], y1) + data = np.random.uniform(size=(5, 4)).astype('float32') + for kind in ["vm"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(data) + tvm.testing.assert_allclose(result.asnumpy(), (data + 1) * 2) + def test_arange_with_dynamic_shape(): m, n, k = relay.ShapeVar('m'), relay.ShapeVar('n'), relay.ShapeVar('k') x = relay.var('x', shape=(m.var, n.var, k.var), dtype='float32') y0 = relay.shape_of(x) y1 = relay.take(y0, relay.const(0, 'int32')) - y2 = relay.op.arange(y1) - ex = relay.create_executor() - f = relay.Function([x], y2, type_params=[m, n, k]) - # TODO(@jroesch): Restore after code generation. - # data = np.random.rand(10, 5, 3).astype('float32') - # result = ex.evaluate(f)(data) - # np.testing.assert_allclose(result.asnumpy(), np.array(range(10))) - -def test_dynamic_concat(): + y2 = relay.op.arange(y1, dtype="int32") + y3 = y2 + relay.const(1, dtype="int32") + data = np.random.rand(10, 5, 3).astype('float32') + mod = relay.module.Module() + mod["main"] = relay.Function([x], y3, type_params=[m, n, k]) + for kind in ["debug", "vm"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(data) + tvm.testing.assert_allclose(result.asnumpy(), np.array(range(10)).astype("int32")+1) + +def test_recursive_concat(): """ fn @concat_loop(%i: int32, %st: (any, 1)) -> (any, 1) { if (%i < 10) { @@ -66,26 +195,18 @@ def test_dynamic_concat(): start = relay.var('start', shape=(), dtype='int32') body = loop(start, relay.op.reshape(relay.const(0), newshape=(1, 1))) func = relay.Function([start], relay.TupleGetItem(body, 1)) - func = infer_type(func) - # TODO(@jroesch, @haichen): We should restore this code when codegeneration - # is merged - # ret_shape = func.checked_type.ret_type.shape - # assert len(ret_shape) == 2, "expected 2-dim output" - # assert relay.ir_pass.alpha_eq(ret_shape[0], relay.Any()) - # import pdb; pdb.set_trace() - # mod = relay.module.Module() - # print(relay.ir_pass.infer_type(func, mod=mod)) - # ret = relay.Call(loop, [relay.const(0, 'int32'), init]) - # mod[mod.entry_func] = relay.Function([], ret) - # print(relay.ir_pass.infer_type(mod[mod.entry_func], mod=mod)) - - # initial = np.array(0.0, dtype='float32').reshape((1,)) - # iter_stop = np.array(10, dtype='int32') - # ex = relay.create_executor("debug", mod=mod, ctx=tvm.cpu(), target="llvm") - # result = ex.evaluate(mod.entry_func)() - # np.testing.assert_allclose(result.asnumpy(), np.array(range(10))) - -def test_dynamic_concat_with_wrong_annotation(): + mod = relay.module.Module() + mod["main"] = func + data = np.array(0.0, dtype='int32') + # TODO(@jroesch): After LambdaLift pass, TypeInfer pass will fail + # so currently we cannot run this test case on VM + for kind in ["debug"]: + ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") + result = ex.evaluate()(data) + ref = np.array([0] + list(range(10))).reshape((11, 1)).astype("int32") + np.testing.assert_allclose(result.asnumpy(), ref) + +def test_recursive_concat_with_wrong_annotation(): """ v0.0.1 fn (%start: int32) { @@ -133,6 +254,12 @@ def test_dynamic_concat_with_wrong_annotation(): assert "in particular dimension 0 conflicts 2 does not match 1" in str(e) if __name__ == "__main__": + test_any_broadcast() + test_any_concat() + test_any_reshape() + test_any_take() + test_any_shape_of() + test_fused_ops() test_arange_with_dynamic_shape() - test_dynamic_concat() - test_dynamic_concat_with_wrong_annotation() + test_recursive_concat() + test_recursive_concat_with_wrong_annotation() diff --git a/tests/python/relay/test_vm_serialization.py b/tests/python/relay/test_vm_serialization.py index 9a8ab2d..a32ec27 100644 --- a/tests/python/relay/test_vm_serialization.py +++ b/tests/python/relay/test_vm_serialization.py @@ -104,9 +104,6 @@ def test_serializer(): vm = create_vm(mod) ser = serializer.Serializer(vm) - stats = ser.stats - assert "scalar" in stats - glbs = ser.globals assert len(glbs) == 3 assert "f1" in glbs @@ -120,8 +117,8 @@ def test_serializer(): code = ser.bytecode assert "main 5 2 5" in code - assert "f1 3 1 4" in code - assert "f2 3 1 4" in code + assert "f1 2 1 3" in code + assert "f2 2 1 3" in code code, lib = ser.serialize() assert isinstance(code, bytearray) diff --git a/tests/python/unittest/test_hybrid_script.py b/tests/python/unittest/test_hybrid_script.py index 805cff8..fd40c3f 100644 --- a/tests/python/unittest/test_hybrid_script.py +++ b/tests/python/unittest/test_hybrid_script.py @@ -122,11 +122,13 @@ def test_outer_product(): assert ibody.min.value == 0 assert ibody.extent.name == 'm' #Check loop body - jbody = ibody.body + jblock = ibody.body + assert isinstance(jblock, tvm.stmt.Block) + jbody = jblock.first assert isinstance(jbody, tvm.stmt.AssertStmt) assert isinstance(jbody.message, tvm.expr.StringImm) assert jbody.message.value == "index out of range!" - jbody = jbody.body + jbody = jblock.rest assert isinstance(jbody, tvm.stmt.Provide) assert jbody.func.name == 'c' assert len(jbody.args) == 2 diff --git a/topi/include/topi/detail/broadcast.h b/topi/include/topi/detail/broadcast.h index b5120ee..2439fab 100644 --- a/topi/include/topi/detail/broadcast.h +++ b/topi/include/topi/detail/broadcast.h @@ -52,6 +52,9 @@ inline BroadcastHelper BroadcastShape(const tvm::Array& shape1, tvm::Expr one(1); int i; for (i = 1; i <= std::min(s1_size, s2_size); ++i) { + // TODO(@icemelon9): Need to revisit this part + const Variable* var1 = shape1[s1_size - i].as(); + const Variable* var2 = shape2[s2_size - i].as(); bh.all_vars.push_front(tvm::Var()); if (topi::detail::EqualCheck(shape1[s1_size - i], shape2[s2_size - i])) { bh.common_shape.push_front(shape1[s1_size - i]); @@ -64,6 +67,16 @@ inline BroadcastHelper BroadcastShape(const tvm::Array& shape1, } else if (topi::detail::EqualCheck(one, shape2[s2_size - i])) { bh.common_shape.push_front(shape1[s1_size - i]); bh.vars1.push_front(bh.all_vars[0]); + } else if (var1 && var2) { + bh.common_shape.push_front(max(shape1[s1_size - i], shape2[s2_size - i])); + bh.vars1.push_front(bh.all_vars[0]); + bh.vars2.push_front(bh.all_vars[0]); + } else if (var1) { + bh.common_shape.push_front(shape2[s2_size - i]); + bh.vars2.push_front(bh.all_vars[0]); + } else if (var2) { + bh.common_shape.push_front(shape1[s1_size - i]); + bh.vars1.push_front(bh.all_vars[0]); } else { CHECK(false) << "Incompatible broadcast dims: " << shape1[s1_size - i] << " and " << shape2[s2_size - i] << " in: " diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index 1622b20..af2ed16 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -1148,9 +1148,9 @@ inline Tensor tensordot(const Tensor& A, return compute(output_shape, func, name, tag); } -inline Tensor arange(const Expr start, - const Expr stop, - const Expr step, +inline Tensor arange(const Expr& start, + const Expr& stop, + const Expr& step, Type dtype, std::string name = "T_arange", std::string tag = kInjective) { diff --git a/topi/tests/python/test_topi_conv2d_winograd.py b/topi/tests/python/test_topi_conv2d_winograd.py index a42d61d..908317e 100644 --- a/topi/tests/python/test_topi_conv2d_winograd.py +++ b/topi/tests/python/test_topi_conv2d_winograd.py @@ -82,9 +82,7 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) func(a, w, c) - rtol = 1e-5 - if (kernel > 3): - rtol = 2e-5 + rtol = 1e-3 tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=rtol)