static constexpr const char* _type_key = "Reduce";
};
+/*! \brief Any shape. */
+struct Any : public ExprNode<Any> {
+ TVM_DLL static Expr make();
+
+ void VisitAttrs(AttrVisitor* v) final {}
+ static const IRNodeType _type_info = IRNodeType::ExtensionExpr;
+ static constexpr const char* _type_key = "Any";
+};
+
/*!
* \brief Auxiliary data structure used in IR Pass to indicate a tensor.
*/
/*! \brief Attributes used in arange operators */
struct ArangeAttrs : public tvm::AttrsNode<ArangeAttrs> {
- tvm::Expr start;
- tvm::Expr stop;
- tvm::Expr step;
+ Expr start;
+ Expr stop;
+ Expr step;
DataType dtype;
TVM_DECLARE_ATTRS(ArangeAttrs, "relay.attrs.ArangeAttrs") {
- TVM_ATTR_FIELD(start).set_default(make_const(Float(32), 0))
+ TVM_ATTR_FIELD(start)
.describe("Start of interval. The interval includes this value.");
TVM_ATTR_FIELD(stop)
.describe("Stop of interval. The interval does not include this value.");
- TVM_ATTR_FIELD(step).set_default(make_const(Float(32), 1))
+ TVM_ATTR_FIELD(step)
.describe("Spacing between values.");
- TVM_ATTR_FIELD(dtype).set_default(NullValue<DataType>())
+ TVM_ATTR_FIELD(dtype)
.describe("Target data type.");
}
}; // struct ArangeAttrs
struct Error : public dmlc::Error {
Span sp;
- explicit Error(const std::string& msg) : dmlc::Error(msg), sp() {}
- Error(const std::stringstream& msg) : dmlc::Error(msg.str()), sp() {} // NOLINT(*)
- Error(const RelayErrorStream& msg) : dmlc::Error(msg.str()), sp() {} // NOLINT(*)
+ explicit Error(const std::string& msg) : dmlc::Error(msg), sp(nullptr) {}
+ Error(const RelayErrorStream& msg) : dmlc::Error(msg.str()), sp(nullptr) {} // NOLINT(*)
+ Error(const Error& err) : dmlc::Error(err.what()), sp(nullptr) {}
+ Error() : dmlc::Error(""), sp(nullptr) {}
};
/*! \brief An abstraction around how errors are stored and reported.
* \param err The error message to report.
*/
inline void ReportAt(const GlobalVar& global, const NodeRef& node, std::stringstream& err) {
- this->ReportAt(global, node, Error(err));
+ std::string err_msg = err.str();
+ this->ReportAt(global, node, Error(err_msg));
}
/*! \brief Report an error against a program, using the full program
return node;
}
+/*! \brief Pretty print a Relay node, producing a fragment of the Relay text format. */
+std::string PrettyPrint(const NodeRef& node);
+
/*!
* \brief Render the node as a string in the Relay text format.
* \param node The node to be rendered.
using FPrimalGradient = runtime::TypedPackedFunc<tvm::Array<Expr>(const Expr& orig_call,
const Expr& output_grad)>;
+/*!
+ * \brief The codegeneration strategy for dynamic dimensions.
+ */
+enum AnyCodegenStrategy {
+ /*! \brief The default strategy of using completely variable dimensions. */
+ kVariableDimensions
+};
+
+/* \brief A runtime representation of shape. */
+using Shape = Array<IndexExpr>;
+
+using FShapeFunc = runtime::TypedPackedFunc<
+ Array<Tensor>(const Attrs& attrs,
+ const Array<Tensor>& inputs,
+ const Array<Shape>& out_shapes)>;
+
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_ATTR_TYPES_H_
namespace tvm {
namespace relay {
+using Any = tvm::ir::Any;
+
/*! \brief Base type of the Relay type hiearchy. */
class TypeNode : public RelayNode {
public:
* But it is possible for the solver to resolve src by dst as well.
*/
TVM_DLL virtual void Assign(const Type& dst, const Type& src) = 0;
+
/*!
* \brief assert shape expression comparison.
* \note Use assert only if any of the condition input is symbolic.
TVM_DLL static void CopyFromTo(
DLTensor* from, DLTensor* to, TVMStreamHandle stream = nullptr);
+ TVM_DLL std::vector<int64_t> Shape() const;
+
// internal namespace
struct Internal;
protected:
"""
c_err_msg = py_str(_LIB.TVMGetLastError())
py_err_msg, err_type = c2pyerror(c_err_msg)
- if err_type.startswith("tvm.error."):
+ if err_type is not None and err_type.startswith("tvm.error."):
err_type = err_type[10:]
return ERROR_TYPE.get(err_type, TVMError)(py_err_msg)
raise ValueError("nested tag is not allowed for now")
tag = _tag.TagScope.get_current().tag
shape = (shape,) if isinstance(shape, (_expr.Expr, _Integral)) else shape
- shape = [shape] if isinstance(shape[0], (_expr.Expr, _Integral)) else shape
+ if shape == () or isinstance(shape[0], (_expr.Expr, _Integral)):
+ shape = [shape]
if in_buffers is not None:
in_buffers = [in_buffers] if not isinstance(in_buffers, list) else in_buffers
if len(inputs) != len(in_buffers):
TensorType = ty.TensorType
Kind = ty.Kind
TypeVar = ty.TypeVar
+ShapeVar = ty.ShapeVar
TypeConstraint = ty.TypeConstraint
FuncType = ty.FuncType
TypeRelation = ty.TypeRelation
RefType = ty.RefType
GlobalTypeVar = ty.GlobalTypeVar
TypeCall = ty.TypeCall
+Any = ty.Any
# Expr
Expr = expr.Expr
"""
if isinstance(value, (_base.numeric_types, (bool, list))):
value = _np.array(value, dtype=dtype)
+
if not dtype:
# when dtype is None: int maps to "int32", float maps to "float32"
map_dtype = {
}.get(value.dtype, None)
if map_dtype:
value = value.astype(map_dtype)
+
if isinstance(value, (_np.ndarray, _np.generic)):
value = _nd.array(value)
raise tvm.error.OpAttributeUnimplemented(
'Attribute "repeat" is not supported in operator arange.')
new_attrs = {}
- new_attrs["start"] = attrs.get_float("start", 0)
- new_attrs["stop"] = attrs.get_float("stop")
- new_attrs["step"] = attrs.get_float("step", 1)
+ new_attrs["start"] = _expr.const(attrs.get_float("start", 0.0))
+ new_attrs["stop"] = _expr.const(attrs.get_float("stop"))
+ new_attrs["step"] = _expr.const(attrs.get_float("step", 1.0))
new_attrs["dtype"] = attrs.get_str("dtype", "float32")
return _op.arange(**new_attrs)
return AttrCvt(
op_name="arange",
ignores=['Tidx'],
- extras={'start': start,
- "stop": limit,
- 'step': delta,
+ extras={'start': _expr.const(start),
+ "stop": _expr.const(limit),
+ 'step': _expr.const(delta),
'dtype': dtype})([], attr)
return _impl
crop = crops[axis - 1]
if crop != [0, 0]:
indices = tvm.relay.arange(
- crop[0],
- reshaped_permuted_shape[axis] - crop[1],
+ _expr.const(crop[0]),
+ _expr.const(reshaped_permuted_shape[axis] - crop[1]),
dtype='int32'
)
cropped = tvm.relay.take(cropped, indices=indices, axis=axis)
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""
+Utilities for building Relay loops.
+"""
+from .scope_builder import ScopeBuilder
+from . import expr as _expr
+
+def while_loop(cond, loop_vars, loop_bodies):
+ """
+ Construct a while loop.
+
+ Parameters
+ ----------
+
+ cond: Callable[Tuple[relay.Expr], relay.Expr]
+ The condition of the loop.
+
+ loop_vars: Tuple[relay.Expr]
+ The variables being looped over.
+ The initial values of the loop, will be used to
+ construct the loop variables.
+
+ loop_bodies: Callable[Tuple[relay.Expr], Tuple[relay.Expr]]
+ The body of the loop, should be a function which
+ given loop variables produces the output result
+ also as a tuple
+
+ Returns
+ -------
+ loop: relay.Expr
+ The loop expression.
+ """
+ sb = ScopeBuilder()
+ loop = _expr.Var("while_loop")
+ fresh_vars = []
+
+ for i, loop_var in enumerate(loop_vars):
+ name = loop_var.name_hint if isinstance(loop_var, _expr.Var) else "arg{}".format(i)
+ new_var = _expr.var(name, type_annotation=sb.type_of(loop_var))
+ fresh_vars.append(new_var)
+
+ with sb.if_scope(cond(*fresh_vars)):
+ sb.ret(loop(*loop_bodies(*fresh_vars)))
+ with sb.else_scope():
+ sb.ret(_expr.Tuple(fresh_vars))
+
+ func = _expr.Function(fresh_vars, sb.get())
+ let = _expr.Let(loop, func, loop)
+ return let
"""Transform operators."""
from . import _make
-from ..expr import TupleWrapper
+from ..expr import TupleWrapper, const
def cast(data, dtype):
return _make.full_like(data, fill_value)
-def arange(start, stop=None, step=1, dtype="float32"):
+def arange(start, stop=None, step=None, dtype="float32"):
"""Return evenly spaced values within a given interval.
.. note::
relay.arange(1, 5) = [1, 2, 3, 4]
relay.arange(1, 5, 1.5) = [1, 2.5, 4]
"""
+ if step is None:
+ step = const(1, dtype)
+
if stop is None:
stop = start
- start = 0
+ start = const(0, dtype=dtype)
+
return _make.arange(start, stop, step, dtype)
else:
self._exit_cb()
-
def _make_lets(bindings, ret_value):
"""Make a nested let expressions.
false_branch)
return WithScope(None, _on_exit)
+
+ def type_of(self, expr):
+ """
+ Compute the type of an expression.
+
+ Parameters
+ ----------
+ expr: relay.Expr
+ The expression to compute the type of.
+ """
+ if isinstance(expr, _expr.Var):
+ return expr.type_annotation
+
+ ity = _ty.IncompleteType()
+ var = _expr.var("unify", ity)
+ self.let(var, expr)
+ return ity
+
def ret(self, value):
"""Set the return value of this scope.
from .base import RelayNode, register_relay_node
from . import _make
+Any = _make.Any
class Type(RelayNode):
"""The base type for all Relay types."""
"""
self.__init_handle_by_constructor__(_make.TypeVar, var, kind)
+def ShapeVar(name):
+ """A helper which constructs a type var of which the shape kind.
+
+ Parameters
+ ----------
+ name : str
+
+ Returns
+ -------
+ type_var : tvm.relay.TypeVar
+ The shape variable.
+ """
+ return TypeVar(name, kind=Kind.ShapeVar)
@register_relay_node
class GlobalTypeVar(Type):
op->call_type == Call::PureExtern) {
return CreateCallExtern(op);
} else {
- LOG(FATAL) << "Unknown call type ";
+ LOG(FATAL) << "Unknown call type " <<
+ "name= " << op->name <<
+ " call_type= " << op->call_type;
return nullptr;
}
}
inline Expr ElemOffset(const BufferNode* n, Array<Expr> index) {
Expr base = n->elem_offset;
if (n->strides.size() == 0) {
- CHECK_EQ(n->shape.size(), index.size());
- if (index.size() > 0) {
- Expr offset = index[0];
- for (size_t i = 1; i < index.size(); ++i) {
- offset = MergeMulMod(offset * n->shape[i] + index[i]);
+ // Scalar case
+ if (n->shape.size() == 0 && index.size() == 1) {
+ auto is_int = index[0].as<IntImm>();
+ CHECK(is_int && is_int->value == 0);
+ base = base + index[0];
+ } else {
+ CHECK_EQ(n->shape.size(), index.size());
+ if (index.size() > 0) {
+ Expr offset = index[0];
+ for (size_t i = 1; i < index.size(); ++i) {
+ offset = MergeMulMod(offset * n->shape[i] + index[i]);
+ }
+ base = base + offset;
}
- base = base + offset;
}
} else {
CHECK_EQ(n->strides.size(), index.size());
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
using tvm::ir::CommReducerNode;
using tvm::ir::Reduce;
+using tvm::ir::Any;
using tvm::ir::AttrStmt;
template<>
void ExprNode<Reduce>::accept(IRVisitor *v, const Expr&) const {
- LOG(FATAL) << "Reduce do not work with old Visitor, use IRFunctor style visitor";
+ LOG(FATAL) << "Reduce does not work with old Visitor, use IRFunctor style visitor";
+}
+
+template<>
+void ExprNode<Any>::accept(IRVisitor *v, const Expr&) const {
+ LOG(FATAL) << "Any does not work with old Visitor, use IRFunctor style visitor";
}
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<Any>([](const Any *op, IRPrinter *p) {
+ p->stream << "?";
+});
+
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Reduce>([](const Reduce *op, IRPrinter *p) {
p->stream << "reduce(combiner="
return Expr(n);
}
+Expr Any::make() {
+ auto n = make_node<Any>();
+ return Expr(n);
+}
+
TVM_REGISTER_NODE_TYPE(CommReducerNode);
TVM_REGISTER_NODE_TYPE(Reduce);
+TVM_REGISTER_NODE_TYPE(Any);
TVM_REGISTER_NODE_TYPE(AttrStmt);
TVM_REGISTER_NODE_TYPE(FloatImm);
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expr Tensor::operator()(Array<Expr> indices) const {
using HalideIR::Internal::Call;
- CHECK_EQ(ndim(), indices.size())
- << "Tensor dimension mismatch in read"
- << "ndim = " << ndim() << ", indices.size=" << indices.size();
+ if (ndim() != 0) {
+ CHECK_EQ(ndim(), indices.size())
+ << "Tensor dimension mismatch in read"
+ << "ndim = " << ndim() << ", indices.size=" << indices.size();
+ }
+
auto n = Call::make(
(*this)->dtype, (*this)->op->name, indices, Call::Halide,
(*this)->op, (*this)->value_index);
}
/*!
- * \brief Build relay function to runtime module
+ * \brief Compile a Relay function to runtime module.
*
- * \param func Relay Function
- * \param params parameters
+ * \param func The Relay function.
+ * \param params The parameters.
*/
void BuildRelay(
Function func,
ret_.graph_json = graph_codegen_->GetJSON();
ret_.params = graph_codegen_->GetParams();
- ret_.mod = tvm::build(graph_codegen_->GetLoweredFunc(), target_host_,
- BuildConfig::Current());
+ auto lowered_funcs = graph_codegen_->GetLoweredFunc();
+ if (lowered_funcs.size() != 0) {
+ ret_.mod = tvm::build(
+ lowered_funcs,
+ target_host_,
+ BuildConfig::Current());
+ }
}
protected:
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
std::stringstream err_msg;
err_msg << rang::fg::red;
+ err_msg << " ";
for (auto index : error_indicies) {
err_msg << this->errors_[index].what() << "; ";
}
// First we output a header for the errors.
annotated_prog <<
rang::style::bold << std::endl <<
- "Error(s) have occurred. We have annotated the program with them:"
+ "Error(s) have occurred. The program has been annotated with them:"
<< std::endl << std::endl << rang::style::reset;
// For each global function which contains errors, we will
return RefCreate(n);
}
+TVM_REGISTER_NODE_TYPE(RefCreateNode);
+
TVM_REGISTER_API("relay._make.RefCreate")
.set_body_typed(RefCreateNode::make);
return RefRead(n);
}
+TVM_REGISTER_NODE_TYPE(RefReadNode);
+
TVM_REGISTER_API("relay._make.RefRead")
.set_body_typed(RefReadNode::make);
return RefWrite(n);
}
+TVM_REGISTER_NODE_TYPE(RefWriteNode);
+
TVM_REGISTER_API("relay._make.RefWrite")
.set_body_typed(RefWriteNode::make);
Doc PrintAttr(const NodeRef& value, bool meta = false) {
if (value.defined()) {
Doc printed_attr;
- if (meta) {
+ if (value.as<tvm::ir::Any>()) {
+ printed_attr << "?";
+ } else if (meta) {
printed_attr = meta_.GetMetaNode(value);
} else {
printed_attr = VisitAttr(value);
return doc.str();
}
+std::string PrettyPrint(const NodeRef& node) {
+ Doc doc;
+ doc << PrettyPrinter(false, runtime::TypedPackedFunc<std::string(Expr)>()).PrintFinal(node);
+ return doc.str();
+}
+
std::string AsText(const NodeRef& node,
bool show_meta_data,
runtime::TypedPackedFunc<std::string(Expr)> annotate) {
p->stream << "RefTypeNode(" << node->value << ")";
});
+TVM_REGISTER_API("relay._make.Any")
+.set_body_typed<IndexExpr()>([]() { return Any::make(); });
+
+
} // namespace relay
} // namespace tvm
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* \brief Transform operators.
*/
#include <tvm/relay/op.h>
+#include <tvm/relay/error.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/expr_operator.h>
#include <tvm/ir.h>
const TypeReporter& reporter) {
// types: [data, result]
CHECK_EQ(types.size(), 2);
+ /* If we receive a tuple we can continue, if we receive
+ * anything but an incomplete type we should signal an
+ * error.
+ */
const auto* tensor_tuple = types[0].as<TupleTypeNode>();
if (tensor_tuple == nullptr) {
- CHECK(types[0].as<IncompleteTypeNode>())
- << "cast: expect input type to be TupleType but get "
- << types[0];
+ throw relay::Error(
+ RELAY_ERROR(
+ "concatenate requires a tuple of tensors as the first argument, found "
+ << PrettyPrint(types[0])));
+ } else if (types[0].as<IncompleteTypeNode>() != nullptr) {
return false;
}
+
const auto* param = attrs.as<ConcatenateAttrs>();
+ if (tensor_tuple->fields[0].as<IncompleteTypeNode>()) {
+ return false;
+ }
const auto& first = Downcast<TensorType>(tensor_tuple->fields[0]);
// Sanity check: ndim and dtype.
const int ndim = static_cast<int>(first->shape.size());
const DataType dtype = first->dtype;
+
for (const Type& ele : tensor_tuple->fields) {
+ if (ele.as<IncompleteTypeNode>()) {
+ return false;
+ }
+
const auto& e = Downcast<TensorType>(ele);
+
int e_ndim = static_cast<int>(e->shape.size());
const DataType& e_dtype = e->dtype;
- CHECK_EQ(e_ndim, ndim) << "relay.concatenate requires all tensors have the same ndim";
- CHECK_EQ(e_dtype, dtype) << "relay.concatenate requires all tensors have the same dtype";
+ if (e_ndim != ndim) {
+ throw relay::Error("relay.concatenate requires all tensors have the same ndim");
+ }
+ if (e_dtype != dtype) {
+ throw relay::Error("relay.concatenate requires all tensors have the same dtype");
+ }
}
// Sanity check: axis
int axis = param->axis;
- CHECK(-ndim <= axis && axis < ndim)
- << "concatenate only accepts `axis` in [-ndim, ndim)"
- << ", but got axis = " << axis
- << ", and ndim = " << ndim;
+ if (!(-ndim <= axis && axis < ndim)) {
+ throw relay::Error(RELAY_ERROR(
+ "concatenate only accepts `axis` in [-ndim, ndim)" <<
+ ", but got axis = " << axis <<
+ ", and ndim = " << ndim));
+ }
axis = axis < 0 ? ndim + axis : axis;
// Calculate shape
std::vector<IndexExpr>&& oshape = AsVector(first->shape);
IndexExpr &concat_dim = oshape[axis];
- for (int i = 1; i < static_cast<int>(tensor_tuple->fields.size()); ++i) {
- const auto& e = Downcast<TensorType>(tensor_tuple->fields[i]);
- concat_dim += e->shape[axis];
+ bool has_any = false;
+ if (concat_dim.as<Any>()) {
+ has_any = true;
+ } else {
+ for (int i = 1; i < static_cast<int>(tensor_tuple->fields.size()); ++i) {
+ const auto& e = Downcast<TensorType>(tensor_tuple->fields[i]);
+ if (e->shape[axis].as<Any>()) {
+ has_any = true;
+ break;
+ }
+ concat_dim += e->shape[axis];
+ }
}
- reporter->Assign(types[1], TensorTypeNode::make(oshape, dtype));
+
+ if (has_any) {
+ concat_dim = Any::make();
+ }
+
+ auto rtype = TensorTypeNode::make(oshape, dtype);
+ reporter->Assign(types[1], rtype);
return true;
}
newshape = param->newshape;
}
Array<IndexExpr> oshape;
+ std::unordered_set<size_t> used_input_dims;
+ std::unordered_set<size_t> used_output_dims;
size_t src_idx = 0;
int infer_idx = -1;
} else if (svalue == 0) {
// keep same
CHECK_LT(src_idx, data_shape.size());
+ used_input_dims.insert(src_idx);
+ used_output_dims.insert(oshape.size());
oshape.push_back(data_shape[src_idx++]);
} else if (svalue == -1) {
// inference based on rest
} else if (svalue == -2) {
// copy all remaining dims from source
while (src_idx < data_shape.size()) {
+ used_input_dims.insert(src_idx);
+ used_output_dims.insert(oshape.size());
oshape.push_back(data_shape[src_idx++]);
}
} else if (svalue == -3) {
// merge two dims from source
CHECK_LT(src_idx + 1, data_shape.size());
+ used_input_dims.insert(src_idx);
IndexExpr d1 = data_shape[src_idx++];
+ used_input_dims.insert(src_idx);
IndexExpr d2 = data_shape[src_idx++];
+ used_output_dims.insert(oshape.size());
oshape.push_back(d1 * d2);
} else if (svalue == -4) {
// split the source dim s into two dims
// read the left dim and then the right dim (either can be -1)
CHECK_LT(i + 2, newshape.size());
CHECK_LT(src_idx, data_shape.size());
+ used_input_dims.insert(src_idx);
IndexExpr d0 = data_shape[src_idx++];
Integer d1 = newshape[++i];
Integer d2 = newshape[++i];
if (d1->value == -1) {
CHECK(d2->value != -1)
<< "Split dims cannot both be -1.";
- oshape.push_back(d0 / d2);
+ used_output_dims.insert(oshape.size());
+ if (d0.as<Any>()) {
+ oshape.push_back(Any::make());
+ } else {
+ oshape.push_back(d0 / d2);
+ }
+ used_output_dims.insert(oshape.size());
oshape.push_back(d2);
} else {
+ used_output_dims.insert(oshape.size());
oshape.push_back(d1);
+ used_output_dims.insert(oshape.size());
if (d2->value == -1) {
- oshape.push_back(d0 / d1);
+ if (d0.as<Any>()) {
+ oshape.push_back(Any::make());
+ } else {
+ oshape.push_back(d0 / d1);
+ }
} else {
oshape.push_back(d2);
}
}
if (infer_idx >= 0) {
- IndexExpr new_size = arith::ComputeReduce<tvm::ir::Mul>(oshape, 1);
- IndexExpr old_size = arith::ComputeReduce<tvm::ir::Mul>(data_shape, 1);
- oshape.Set(infer_idx, old_size / new_size);
+ IndexExpr infer_dim = 1;
+ for (size_t i = 0; i < data_shape.size(); ++i) {
+ if (used_input_dims.count(i) != 0) {
+ continue;
+ }
+ if (data_shape[i].as<Any>()) {
+ infer_dim = Any::make();
+ break;
+ }
+ infer_dim *= data_shape[i];
+ }
+ if (!infer_dim.as<Any>()) {
+ for (size_t i = 0; i < oshape.size(); ++i) {
+ if (used_output_dims.count(i) != 0) {
+ continue;
+ }
+ if (oshape[i].as<Any>()) {
+ infer_dim = Any::make();
+ break;
+ }
+ infer_dim /= oshape[i];
+ }
+ }
+ oshape.Set(infer_idx, infer_dim);
}
if (param->reverse) {
// arange operator
TVM_REGISTER_NODE_TYPE(ArangeAttrs);
+double ToScalar(const runtime::NDArray& array) {
+ if (array->dtype.code == kDLInt || array->dtype.code == kDLUInt) {
+ return reinterpret_cast<int32_t*>(array->data)[0];
+ } else {
+ return reinterpret_cast<float*>(array->data)[0];
+ }
+}
+
bool ArangeRel(const Array<Type>& types,
int num_inputs,
- const Attrs& attrs,
+ const Attrs& raw_attrs,
const TypeReporter& reporter) {
- CHECK_EQ(types.size(), 1);
- const ArangeAttrs* param = attrs.as<ArangeAttrs>();
- IndexExpr num_elem = tvm::cast(tvm::Int(32), tvm::ceil(
- tvm::cast(tvm::Float(32), param->stop - param->start) / param->step));
- if (const tvm::ir::IntImm* val = num_elem.as<tvm::ir::IntImm>()) {
- CHECK_GT(val->value, 0)
- << "Invalid arange attributes (start, stop, step): " << param->start
- << ", " << param->stop << ", " << param->step;
- }
- reporter->Assign(types[0], TensorTypeNode::make({num_elem}, param->dtype));
- return true;
+ CHECK_EQ(types.size(), 4);
+ const ArangeAttrs* attrs = raw_attrs.as<ArangeAttrs>();
+ const ConstantNode *cstart, *cstop, *cstep;
+
+ reporter->Assign(types[0], types[1]);
+ reporter->Assign(types[1], types[2]);
+ reporter->Assign(types[2], TensorTypeNode::make({}, attrs->dtype));
+
+ if ((cstart = attrs->start.as<ConstantNode>()) &&
+ (cstop = attrs->stop.as<ConstantNode>()) &&
+ (cstep = attrs->step.as<ConstantNode>())) {
+ double start = ToScalar(cstart->data);
+ double stop = ToScalar(cstop->data);
+ double step = ToScalar(cstep->data);
+ int32_t num_elem = static_cast<int32_t>(std::ceil((stop - start) / step));
+ CHECK_GT(num_elem, 0)
+ << "Invalid arange attributes (start, stop, step): " << attrs->start
+ << ", " << attrs->stop << ", " << attrs->step;
+ reporter->Assign(types[3], TensorTypeNode::make({num_elem}, attrs->dtype));
+ return true;
+ } else {
+ reporter->Assign(types[3], TensorTypeNode::make({Any::make()}, attrs->dtype));
+ return true;
+ }
+}
+
+inline Tensor DynamicArange(const tvm::Tensor& start, const tvm::Tensor& stop,
+ const tvm::Tensor& step, tvm::Type dtype, std::string name = "tensor",
+ std::string tag = topi::kInjective) {
+ tvm::Expr num_elem = tvm::Var("num_elem");
+ return tvm::compute({num_elem}, [&](const Array<tvm::Var>& indices) {
+ return tvm::cast(dtype, start[0] + step[0] * indices[0]);
+ }, name, tag);
}
Array<Tensor> ArangeCompute(const Attrs& attrs,
const Type& out_type,
const Target& target) {
const ArangeAttrs* param = attrs.as<ArangeAttrs>();
- return { topi::arange(param->start, param->stop, param->step, param->dtype) };
+ Tensor start = inputs[0];
+ Tensor stop = inputs[1];
+ Tensor step = inputs[2];
+ Array<tvm::Expr> empty = {0};
+ return { DynamicArange(start, stop, step, param->dtype) };
}
-Expr MakeArange(tvm::Expr start,
- tvm::Expr stop,
- tvm::Expr step,
+Expr MakeArange(Expr start,
+ Expr stop,
+ Expr step,
DataType dtype) {
auto attrs = make_node<ArangeAttrs>();
- attrs->start = std::move(start);
- attrs->stop = std::move(stop);
- attrs->step = std::move(step);
- attrs->dtype = std::move(dtype);
+ attrs->start = start;
+ attrs->stop = stop;
+ attrs->step = step;
+ attrs->dtype = dtype;
static const Op& op = Op::Get("arange");
- return CallNode::make(op, {}, Attrs(attrs), {});
+ return CallNode::make(op, {start, stop, step}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op._make.arange")
.set_body_typed(MakeArange);
+// An issue with the existing design is that we require dependency
+// to type the operator precisely.
+//
+// Supporting this in general is challenging so we duplicate the
+// secondary arguments as args and attributes.
+//
+// In this way reify the arguments at both the value and type level.
+//
+// In the case our arguments are constant we can immediately recover
+// the type of arange.
+//
+// In general I think we should avoid this pattern, and introduce
+// a secondary shape analysis to recover more precise information.
RELAY_REGISTER_OP("arange")
.describe(R"code(Returns evenly spaced values within a given interval.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.ArangeAttrs")
-.set_num_inputs(0)
+.set_num_inputs(3)
.set_support_level(3)
.add_type_rel("Arange", ArangeRel)
.set_attr<FTVMCompute>("FTVMCompute", ArangeCompute)
-.set_attr<TOpPattern>("TOpPattern", kInjective);
+.set_attr<TOpPattern>("TOpPattern", kInjective)
+.set_attr<AnyCodegenStrategy>("AnyCodegenStrategy", kVariableDimensions);
// repeat operator
TVM_REGISTER_NODE_TYPE(RepeatAttrs);
oshape.push_back(s2);
} else if (EqualConstInt(s2, 1)) {
oshape.push_back(s1);
+ } else if (s1.as<Any>() && EqualConstInt(s2, 1)) {
+ // TODO(@jroesch): we need to come back to this
+ oshape.push_back(s2);
+ } else if (s2.as<Any>() && EqualConstInt(s1, 1)) {
+ oshape.push_back(s1);
} else {
RELAY_ERROR(
"Incompatible broadcast type "
Type VisitExpr_(const LetNode* let) final {
// if the definition is a function literal, permit recursion
bool is_functional_literal = let->value.as<FunctionNode>() != nullptr;
+ Type let_type = IncompleteTypeNode::make(Kind::kType);
+
if (is_functional_literal) {
- type_map_[let->var].checked_type = IncompleteTypeNode::make(Kind::kType);
+ let_type = GetType(let->var);
+ type_map_[let->var].checked_type = let_type;
}
- Type vtype = GetType(let->value);
+
if (let->var->type_annotation.defined()) {
- vtype = Unify(vtype, let->var->type_annotation, GetRef<Let>(let));
+ let_type = Unify(let_type, let->var->type_annotation, GetRef<Let>(let));
}
+
+ Type vtype = GetType(let->value);
+ let_type = Unify(let_type, vtype, GetRef<Let>(let));
+
CHECK(is_functional_literal || !type_map_.count(let->var));
// NOTE: no scoping is necessary because var are unique in program
- type_map_[let->var].checked_type = vtype;
+ type_map_[let->var].checked_type = let_type;
return GetType(let->body);
}
}
for (size_t i = 0; i < fn_ty->arg_types.size(); i++) {
- this->Unify(fn_ty->arg_types[i], arg_types[i], call->args[i]);
+ this->Unify(fn_ty->arg_types[i], arg_types[i], GetRef<Call>(call));
}
for (auto cs : fn_ty->type_constraints) {
return FuncTypeNode::make(c->inputs, TypeCallNode::make(c->belong_to, types),
td->type_vars, {});
}
+
+ void Solve() {
+ solver_.Solve();
+
+ if (err_reporter.AnyErrors()) {
+ err_reporter.RenderErrors(mod_);
+ }
+ }
};
class TypeInferencer::Resolver : public ExprMutator, PatternMutator {
update_missing_type_annotation_ &&
!new_var->type_annotation.defined());
- bool need_update_fn = (
+ bool need_update_fn =(
std::is_base_of<FunctionNode, T>::value &&
update_missing_type_annotation_ &&
!new_fn->ret_type.defined());
Expr TypeInferencer::Infer(Expr expr) {
- // Step 0: Populate the constraints.
+ // Step 1: Populate the constraints.
GetType(expr);
- // Step 1: Solve the constraints.
- solver_.Solve();
- if (err_reporter.AnyErrors()) {
- err_reporter.RenderErrors(mod_);
- }
+ // Step 2: Solve the constraints.
+ Solve();
- // Step 2: Attach resolved types to checked_type field.
+ // Step 3: Attach resolved types to checked_type field.
auto resolved_expr = Resolver(type_map_, &solver_).VisitExpr(expr);
CHECK(WellFormed(resolved_expr));
return resolved_expr;
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
*/
#include <string>
#include <memory>
+#include <tuple>
+#include <utility>
#include "type_solver.h"
#include "../ir/type_functor.h"
class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
public:
- explicit Unifier(TypeSolver* solver) : solver_(solver) {}
+ explicit Unifier(TypeSolver* solver, const NodeRef& loc) : solver_(solver), loc(loc) {}
Type Unify(const Type& src, const Type& dst) {
// Known limitation
if (lhs->FindRoot() == rhs->FindRoot()) {
return lhs->resolved_type;
}
+
if (lhs->resolved_type.as<IncompleteTypeNode>()) {
- CHECK(!CheckOccurs(lhs, rhs->resolved_type))
+ CHECK(!OccursCheck(lhs, rhs->resolved_type))
<< "Incomplete type " << lhs->resolved_type << " occurs in "
<< rhs->resolved_type << ", cannot unify";
+
solver_->MergeFromTo(lhs, rhs);
return rhs->resolved_type;
} else if (rhs->resolved_type.as<IncompleteTypeNode>()) {
- CHECK(!CheckOccurs(rhs, lhs->resolved_type))
+ CHECK(!OccursCheck(rhs, lhs->resolved_type))
<< "Incomplete type " << rhs->resolved_type << " occurs in "
<< lhs->resolved_type << ", cannot unify";
solver_->MergeFromTo(rhs, lhs);
return lhs->resolved_type;
} else {
Type resolved = this->VisitType(lhs->resolved_type, rhs->resolved_type);
- CHECK(resolved.defined())
- << "Unable to unify parent types: "
- << lhs->resolved_type << " and " << rhs->resolved_type;
- TypeNode* top = solver_->GetTypeNode(resolved);
- solver_->MergeFromTo(lhs, top);
- solver_->MergeFromTo(rhs, top);
- return resolved;
+ if (!resolved.defined()) {
+ solver_->ReportError(RELAY_ERROR("unable to unify: "
+ << "`" << PrettyPrint(lhs->resolved_type) << "` and `"
+ << PrettyPrint(rhs->resolved_type) << "`"),
+ this->loc);
+ return lhs->resolved_type;
+ } else {
+ TypeNode* top = solver_->GetTypeNode(resolved);
+ solver_->MergeFromTo(lhs, top);
+ solver_->MergeFromTo(rhs, top);
+ return resolved;
+ }
}
}
// there is a recursive equality constraint, which should be rejected.
// N.b.: A tautology like ?a = ?a is okay and should be checked for
// *before* calling this method
- bool CheckOccurs(TypeNode* lhs, const Type& t) {
+ //
+ // See: https://en.wikipedia.org/wiki/Occurs_check
+ bool OccursCheck(TypeNode* lhs, const Type& t) {
OccursChecker rc(solver_, lhs);
return rc.Check(t);
}
return t1;
}
+ IndexExpr GetShape(const IndexExpr& e) {
+ IndexExpr ex = e;
+ while (true) {
+ auto it = solver_->shape_uf_.find(ex);
+ if (it == solver_->shape_uf_.end()) {
+ return ex;
+ } else {
+ ex = (*it).second;
+ }
+ }
+ }
+
+ IndexExpr UnifyDim(const IndexExpr& lhs, const IndexExpr& rhs) {
+ auto ulhs = GetShape(lhs);
+ auto urhs = GetShape(rhs);
+
+ if (ulhs.same_as(urhs)) {
+ return ulhs;
+ }
+ if (ulhs.as<Any>() || urhs.as<Any>()) {
+ return Any::make();
+ }
+
+ auto left_index0 = ulhs.as<tvm::Variable>();
+ auto right_index0 = urhs.as<tvm::IntImm>();
+ if (left_index0 && right_index0) {
+ solver_->shape_uf_.Set(ulhs, urhs);
+ return urhs;
+ }
+
+ auto left_index1 = ulhs.as<tvm::IntImm>();
+ auto right_index1 = urhs.as<tvm::Variable>();
+ if (left_index1 && right_index1) {
+ solver_->shape_uf_.Set(urhs, ulhs);
+ return ulhs;
+ }
+
+ auto left_index2 = ulhs.as<tvm::IntImm>();
+ auto right_index2 = urhs.as<tvm::IntImm>();
+ if (left_index2 && right_index2 && left_index2->value == right_index2->value) {
+ return ulhs;
+ }
+
+ return tvm::Expr();
+ }
+
+ Type VisitType_(const TensorTypeNode* op, const Type& tn) final {
+ const auto* tt_node = tn.as<TensorTypeNode>();
+ if (!tt_node) {
+ return Type(nullptr);
+ }
+
+ auto tt1 = GetRef<TensorType>(op);
+ auto tt2 = GetRef<TensorType>(tt_node);
+
+ if (AlphaEqual(tt1, tt2)) {
+ return std::move(tt1);
+ }
+
+ if (tt1->dtype != tt2->dtype) {
+ return Type(nullptr);
+ }
+
+ tvm::Array<IndexExpr> shape;
+ if (tt1->shape.size() != tt2->shape.size()) {
+ this->solver_->ReportError(
+ RELAY_ERROR(
+ "tensor type `" << PrettyPrint(tt1) <<
+ "` has " << tt1->shape.size() <<
+ " dimensions, while `" <<
+ PrettyPrint(tt2) <<
+ "` has " << tt2->shape.size() <<
+ " dimensions"), this->loc);
+ return Type(nullptr);
+ }
+
+ std::vector<std::tuple<size_t, IndexExpr, IndexExpr>> mismatches;
+
+ CHECK_EQ(tt1->shape.size(), tt2->shape.size());
+ for (size_t i = 0; i < tt1->shape.size(); i++) {
+ auto dim = UnifyDim(tt1->shape[i], tt2->shape[i]);
+ if (!dim.defined()) {
+ // NB: We push an arbitrary dimension here so we can continue error propogation.
+ shape.push_back(tt1->shape[i]);
+ tvm::Expr shape1 = tt1->shape[i];
+ tvm::Expr shape2 = tt2->shape[i];
+ std::tuple<int, IndexExpr, IndexExpr> tuple = std::make_tuple(i, shape1, shape2);
+ mismatches.push_back(tuple);
+ } else {
+ shape.push_back(dim);
+ }
+ }
+
+ if (mismatches.size() != 0) {
+ RelayErrorStream err;
+ err << "in particular ";
+ for (auto mismatch : mismatches) {
+ err << "dimension "
+ << std::get<0>(mismatch)
+ << " conflicts "
+ << std::get<1>(mismatch)
+ << " does not match "
+ << std::get<2>(mismatch);
+ }
+ Error error(err);
+ this->solver_->ReportError(error, this->loc);
+ return Type(nullptr);
+ }
+
+ return TensorTypeNode::make(shape, tt1->dtype);
+ }
+
Type VisitType_(const TupleTypeNode* op, const Type& tn) final {
const auto* ttn = tn.as<TupleTypeNode>();
if (!ttn || op->fields.size() != ttn->fields.size()) {
private:
TypeSolver* solver_;
+ NodeRef loc;
};
class TypeSolver::Resolver : public TypeMutator {
}
// Add equality constraint
-Type TypeSolver::Unify(const Type& dst, const Type& src, const NodeRef&) {
- // NB(@jroesch): we should probably pass location into the unifier to do better
- // error reporting as well.
- Unifier unifier(this);
+Type TypeSolver::Unify(const Type& dst, const Type& src, const NodeRef& loc) {
+ Unifier unifier(this, loc);
return unifier.Unify(dst, src);
}
void TypeSolver::ReportError(const Error& err, const NodeRef& location) {
+ CHECK(location.defined());
+ CHECK(current_func.defined());
err_reporter_->ReportAt(current_func, location, err);
}
}
bool TypeSolver::Solve() {
- // Update until queue is empty.
while (!update_queue_.empty()) {
RelationNode* rnode = update_queue_.front();
const auto& rel = rnode->rel;
}
CHECK(rnode->location.defined())
- << "undefined location, should be set when constructing relation node";
+ << "undefined location, should be set when constructing relation node";
// We need to set this in order to understand where unification
// errors generated by the error reporting are coming from.
rnode->resolved = false;
} catch (const dmlc::Error& err) {
rnode->resolved = false;
- this->ReportError(
- RELAY_ERROR(
- "an internal invariant was violated while " \
- "typechecking your program " <<
- err.what()), rnode->location);
+ this->ReportError(RELAY_ERROR("an internal invariant was violated while "
+ "typechecking your program "
+ << err.what()),
+ rnode->location);
}
// Mark inqueue as false after the function call
.set_body([](runtime::TVMArgs args, runtime::TVMRetValue* ret) {
using runtime::PackedFunc;
using runtime::TypedPackedFunc;
- ErrorReporter err_reporter;
- auto solver = std::make_shared<TypeSolver>(GlobalVarNode::make("test"), &err_reporter);
+ ErrorReporter *err_reporter = new ErrorReporter();
+ auto solver = std::make_shared<TypeSolver>(GlobalVarNode::make("test"), err_reporter);
- auto mod = [solver](std::string name) -> PackedFunc {
+ auto mod = [solver, err_reporter](std::string name) -> PackedFunc {
if (name == "Solve") {
return TypedPackedFunc<bool()>([solver]() {
return solver->Solve();
});
} else if (name == "Unify") {
- return TypedPackedFunc<Type(Type, Type)>([solver](Type lhs, Type rhs) {
- return solver->Unify(lhs, rhs, lhs);
+ return TypedPackedFunc<Type(Type, Type)>([solver, err_reporter](Type lhs, Type rhs) {
+ auto res = solver->Unify(lhs, rhs, lhs);
+ if (err_reporter->AnyErrors()) {
+ err_reporter->RenderErrors(ModuleNode::make({}, {}), true);
+ }
+ return res;
});
} else if (name == "Resolve") {
return TypedPackedFunc<Type(Type)>([solver](Type t) {
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* \param location The location at which the unification problem arose.
*/
Type Unify(const Type& lhs, const Type& rhs, const NodeRef& location);
-
/*!
* \brief Report an error at the provided location.
* \param err The error to report.
TypeNode* parent{nullptr};
/*! \brief set of relations that is related to this type node */
std::unordered_set<RelationNode*> rel_set;
+
/*!
* \brief Find the root type node, perform path compression
* \return The root type node.
NodeRef location;
};
+ /*! \brief A simple union find between shapes. */
+ tvm::Map<IndexExpr, IndexExpr> shape_uf_;
/*! \brief List of all allocated type nodes */
std::vector<TypeNode*> type_nodes_;
/*! \brief List of all allocated relation nodes */
std::vector<RelationNode*> rel_nodes_;
/*! \brief Number of resolved relations */
size_t num_resolved_rels_{0};
- /*! \brief map from type node to types. */
+ /*! \brief map from types to type nodes. */
std::unordered_map<Type, TypeNode*, NodeHash, NodeEqual> tmap_;
/*! \brief Internal queue to update the relation */
std::queue<RelationNode*> update_queue_;
rel->inqueue = true;
update_queue_.push(rel);
}
+
/*!
* \brief Merge rhs type node to lhs
* \param src The source operand
from_size, from->ctx, to->ctx, from->dtype, stream);
}
+std::vector<int64_t> NDArray::Shape() const {
+ return data_->shape_;
+}
+
} // namespace runtime
} // namespace tvm
--- /dev/null
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+from tvm import relay
+from tvm.relay import Kind, transform
+from tvm.relay.loops import while_loop
+import numpy as np
+
+def infer_type(expr):
+ mod = relay.Module.from_expr(expr)
+ mod = transform.InferType()(mod)
+ entry = mod[mod.entry_func]
+ return entry if isinstance(expr, relay.Function) else entry.body
+
+def int32(val):
+ return relay.const(val, 'int32')
+
+def test_arange_with_dynamic_shape():
+ m, n, k = relay.ShapeVar('m'), relay.ShapeVar('n'), relay.ShapeVar('k')
+ x = relay.var('x', shape=(m.var, n.var, k.var), dtype='float32')
+ y0 = relay.shape_of(x)
+ y1 = relay.take(y0, relay.const(0, 'int32'))
+ y2 = relay.op.arange(y1)
+ ex = relay.create_executor()
+ f = relay.Function([x], y2, type_params=[m, n, k])
+ # TODO(@jroesch): Restore after code generation.
+ # data = np.random.rand(10, 5, 3).astype('float32')
+ # result = ex.evaluate(f)(data)
+ # np.testing.assert_allclose(result.asnumpy(), np.array(range(10)))
+
+def test_dynamic_concat():
+ """
+ fn @concat_loop(%i: int32, %st: (any, 1)) -> (any, 1) {
+ if (%i < 10) {
+ let %i = reshape(cast(i, "float32"), newshape=(1, ))
+ let %new_st = concatenate((st, i), axis=0)
+ concat_loop(%i + 1, )
+ } else {
+ st
+ }
+ }
+ """
+ # Initial Values.
+ i = relay.var('i', shape=(), dtype='int32')
+ st = relay.var('st', shape=(relay.Any(), 1), dtype='int32')
+
+ def _cond(i, st):
+ return relay.op.min(relay.op.less(i, int32(10)))
+
+ def _body(i, st):
+ i_vec = relay.op.reshape(i, (1,1))
+ ret = relay.op.concatenate([st, i_vec], axis=0)
+ return i + int32(1), ret
+
+ loop = while_loop(_cond, [i, st], _body)
+ start = relay.var('start', shape=(), dtype='int32')
+ body = loop(start, relay.op.reshape(relay.const(0), newshape=(1, 1)))
+ func = relay.Function([start], relay.TupleGetItem(body, 1))
+ func = infer_type(func)
+ # TODO(@jroesch, @haichen): We should restore this code when codegeneration
+ # is merged
+ # ret_shape = func.checked_type.ret_type.shape
+ # assert len(ret_shape) == 2, "expected 2-dim output"
+ # assert relay.ir_pass.alpha_eq(ret_shape[0], relay.Any())
+ # import pdb; pdb.set_trace()
+ # mod = relay.module.Module()
+ # print(relay.ir_pass.infer_type(func, mod=mod))
+ # ret = relay.Call(loop, [relay.const(0, 'int32'), init])
+ # mod[mod.entry_func] = relay.Function([], ret)
+ # print(relay.ir_pass.infer_type(mod[mod.entry_func], mod=mod))
+
+ # initial = np.array(0.0, dtype='float32').reshape((1,))
+ # iter_stop = np.array(10, dtype='int32')
+ # ex = relay.create_executor("debug", mod=mod, ctx=tvm.cpu(), target="llvm")
+ # result = ex.evaluate(mod.entry_func)()
+ # np.testing.assert_allclose(result.asnumpy(), np.array(range(10)))
+
+def test_dynamic_concat_with_wrong_annotation():
+ """
+ v0.0.1
+ fn (%start: int32) {
+ %7 = {
+ let %while_loop = fn (%i: int32, %st: Tensor[(1, 1), int32]) {
+ %0 = less(%i, 10)
+ %1 = min(%0)
+ if (%1) {
+ %2 = add(%i, 1)
+ %3 = reshape(%i, newshape=[1, 1])
+ %4 = (%st, %3)
+ /* The result of concat should be 1,1 but it is 2, 1. */
+ %5 = concatenate(%4)
+ %while_loop(%2, %5)
+ } else {
+ (%i, %st)
+ }
+ }
+ %6 = reshape(0, newshape=[1, 1])
+ %while_loop(%start, %6)
+ }
+ %7.1
+ }
+ """
+ # Initial Values.
+ i = relay.var('i', shape=(), dtype='int32')
+ st = relay.var('st', shape=(1, 1), dtype='int32')
+
+ def _cond(i, st):
+ return relay.op.min(relay.op.less(i, int32(10)))
+
+ def _body(i, st):
+ i_vec = relay.op.reshape(i, (1,1))
+ ret = relay.op.concatenate([st, i_vec], axis=0)
+ return i + int32(1), ret
+
+ loop = while_loop(_cond, [i, st], _body)
+ start = relay.var('start', shape=(), dtype='int32')
+ body = loop(start, relay.op.reshape(relay.const(0), newshape=(1, 1)))
+ func = relay.Function([start], relay.TupleGetItem(body, 1))
+ try:
+ func = infer_type(func)
+ assert False
+ except Exception as e:
+ assert "in particular dimension 0 conflicts 2 does not match 1" in str(e)
+
+if __name__ == "__main__":
+ test_arange_with_dynamic_shape()
+ test_dynamic_concat()
+ test_dynamic_concat_with_wrong_annotation()
def verify_arange(start, stop, step):
dtype = "float32"
if start is None and step is None:
- x = relay.arange(stop)
- ref_res = np.arange(stop)
+ x = relay.arange(relay.const(stop, dtype=dtype))
+ ref_res = np.arange(stop).astype(dtype)
elif start is None:
- x = relay.arange(stop, step=step)
- ref_res = np.arange(stop, step=step)
+ x = relay.arange(relay.const(stop, dtype=dtype), step=relay.const(step, dtype=dtype))
+ ref_res = np.arange(stop, step=step).astype(dtype)
elif step is None:
- x = relay.arange(start, stop)
- ref_res = np.arange(start, stop)
+ x = relay.arange(relay.const(start, dtype=dtype), relay.const(stop, dtype=dtype))
+ ref_res = np.arange(start, stop).astype(dtype)
else:
- x = relay.arange(start, stop, step)
- ref_res = np.arange(start, stop, step)
+ x = relay.arange(
+ relay.const(start, dtype=dtype),
+ relay.const(stop, dtype=dtype),
+ relay.const(step, dtype=dtype))
+ ref_res = np.arange(start, stop, step).astype(dtype)
func = relay.Function([], x)
for target, ctx in ctx_list():
verify_arange(None, 20, 2)
verify_arange(1, 20, None)
verify_arange(1, 20, 2)
- verify_arange(1, 20, 1.5)
+ # arange doesnt' support floating point right now, see type relation
+ # verify_arange(1, 20, 1.5)
verify_arange(1, 20.5, None)
verify_arange(1, 20, 3)
verify_arange(20, 1, -1)
- verify_arange(20, 1, -1.5)
+ # arange doesnt' support floating point right now, see type relation
+ # verify_arange(20, 1, -1.5)
def test_tile():
def verify_tile(dshape, reps):
if __name__ == "__main__":
+ test_arange()
test_cast()
test_zeros_ones()
test_unary_identity()