}
};
+/*!
+ * \brief Attributes for VM reshape_tensor operator.
+ */
+struct ReshapeTensorAttrs : public tvm::AttrsNode<ReshapeTensorAttrs> {
+  // Statically known output shape. Dims may be symbolic (non-IntImm) for
+  // dynamic reshapes; ReshapeTensorRel assigns the output type from this
+  // field rather than from the runtime shape tensor.
+  Array<PrimExpr> newshape;
+
+  TVM_DECLARE_ATTRS(ReshapeTensorAttrs, "relay.attrs.ReshapeTensorAttrs") {
+    TVM_ATTR_FIELD(newshape).describe("The new shape of output tensor");
+  }
+};
+
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_VM_H_
Fatal = 15U,
AllocStorage = 16U,
ShapeOf = 17U,
+ ReshapeTensor = 18U,
};
/*! \brief A single virtual machine instruction.
struct /* ShapeOf Operands */ {
RegName tensor;
} shape_of;
+ struct /* ReshapeTensor Operands */ {
+ RegName tensor;
+ RegName newshape;
+ } reshape_tensor;
};
/*!
*/
static Instruction ShapeOf(RegName tensor, RegName dst);
+ /*!
+ * \brief Reshape the tensor given the new shape.
+ * \param tensor The input tensor.
+ * \param newshape The shape tensor.
+ * \param dst The destination to store the output tensor with new shape.
+ * \return The reshape tensor instruction.
+ */
+ static Instruction ReshapeTensor(RegName tensor, RegName newshape, RegName dst);
+
Instruction();
Instruction(const Instruction& instr);
Instruction& operator=(const Instruction& instr);
new_fields.append(field)
ret_type = _ty.TupleType(new_fields)
- is_dyn = _ty.type_has_any(call.checked_type)
+ is_dyn = _ty.is_dynamic(call.checked_type)
for arg in call.args:
- is_dyn = is_dyn or _ty.type_has_any(arg.checked_type)
+ is_dyn = is_dyn or _ty.is_dynamic(arg.checked_type)
# check if in the AutoTVM tracing mode, and disable if op is not in wanted list
env = autotvm.task.TaskExtractEnv.current
import tvm.runtime.vm as vm_rt
from tvm import autotvm
from tvm.relay import expr as _expr
-from tvm.relay.ty import type_has_any
+from tvm.relay.ty import is_dynamic
from tvm.relay.backend.interpreter import Executor
from . import _vm
def _vm_wrapper(*args, **kwargs):
args = self._convert_args(main, args, kwargs)
ret_type = self.mod["main"].checked_type.ret_type
- if type_has_any(ret_type) and "llvm" not in str(self.target) and "arm" not in str(
+ if is_dynamic(ret_type) and "llvm" not in str(self.target) and "arm" not in str(
self.target):
raise ValueError(
"Virtual Machine only supports dynamic graphs on CPU, got output type",
if expr:
self.mod["main"] = expr
ret_type = self.mod["main"].checked_type.ret_type
- if _ty.type_has_any(ret_type):
+ if _ty.is_dynamic(ret_type):
raise ValueError("Graph Runtime only supports static graphs, got output type",
ret_type)
num_outputs = len(ret_type.fields) if isinstance(ret_type, _ty.TupleType) else 1
The shape function expression.
"""
return _ffi_api.shape_func(func, inputs, outputs, is_inputs)
+
+
+def reshape_tensor(data, shape, newshape):
+    """Invoke the VM ReshapeTensor instruction.
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input data.
+
+    shape : tvm.relay.Expr
+        The newshape tensor, i.e. the expression that yields the output
+        shape at runtime.
+
+    newshape : List[tvm.ir.PrimExpr]
+        The new shape as statically known; used for type inference and
+        may contain symbolic dimensions for dynamic reshapes.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        A call to the ``vm.reshape_tensor`` operator.
+    """
+    return _ffi_api.reshape_tensor(data, shape, newshape)
A pass for manifesting explicit memory allocations.
"""
import numpy as np
-from ..expr_functor import ExprMutator
+from ..expr_functor import ExprVisitor, ExprMutator
from ..scope_builder import ScopeBuilder
from . import transform
from .. import op
return hasattr(call, 'op') and hasattr(call.op, 'attrs') and \
hasattr(call.op.attrs, 'Primitive') and int(call.op.attrs.Primitive) == 1
+
+class CheckReshapeOnly(ExprVisitor):
+    """A pass to check if the fused op contains only reshape ops."""
+    def __init__(self):
+        super().__init__()
+        # Ops considered pure reshapes: they only reinterpret the data
+        # layout, so a fused function built solely from them can be
+        # lowered to a single VM ReshapeTensor instruction.
+        self._reshape_ops = [op.get("reshape"), op.get("contrib_reverse_reshape"),
+                             op.get("dyn.reshape")]
+        self.reshape_only = True
+
+    def visit_call(self, call):
+        # Prune the traversal once a non-reshape op has been seen; the
+        # verdict cannot change back to True.
+        if not self.reshape_only:
+            return
+        if call.op not in self._reshape_ops:
+            self.reshape_only = False
+        for arg in call.args:
+            self.visit(arg)
+
+
+def is_reshape_only(func):
+    """Check if the primitive function contains only reshape ops."""
+    check = CheckReshapeOnly()
+    check.visit(func)
+    return check.reshape_only
+
+
class ManifestAllocPass(ExprMutator):
"""A pass for explicitly manifesting all memory allocations in Relay."""
self.invoke_tvm = op.vm.invoke_tvm_op
self.shape_func = op.vm.shape_func
self.shape_of = op.vm.shape_of
+ self.reshape_tensor = op.vm.reshape_tensor
self.scopes = [ScopeBuilder()]
self.target_host = target_host
self.default_context = cpu(0)
return scope.get()
- def dynamic_invoke(self, scope, func, ins, new_args, out_types, ret_type):
- """Generate the code for invoking a TVM op with a dynamic shape."""
+ def emit_shape_func(self, scope, func, new_args):
+ """Insert the shape function given a primitive function."""
shape_func_ins = []
engine = compile_engine.get()
cfunc = engine.lower_shape_func(func, self.target_host)
expr.Tuple(out_shapes), is_inputs)
scope.let("shape_func", shape_call)
+ return out_shapes
+
+ def dynamic_invoke(self, scope, func, ins, new_args, out_types, ret_type):
+ """Generate the code for invoking a TVM op with a dynamic shape."""
+ out_shapes = self.emit_shape_func(scope, func, new_args)
storages = []
- for out_shape, out_type in zip(out_shapes, out_types):
+ for i, (out_shape, out_type) in enumerate(zip(out_shapes, out_types)):
size = self.compute_storage_in_relay(
out_shape, out_type.dtype)
alignment = self.compute_alignment(out_type.dtype)
scope.let("", invoke)
return to_tuple_type(ret_type, tuple_outs.fields)
+    def emit_reshape_tensor(self, scope, func, new_args, ret_type):
+        """Emit a vm.reshape_tensor call for a reshape-only primitive.
+
+        If the output type is dynamic, the output shape is produced at
+        runtime by the shape function; otherwise the constant output
+        shape is embedded directly as a constant tensor. The static
+        ``ret_type.shape`` is always passed as the ``newshape`` attr so
+        type inference has the compile-time shape.
+        """
+        if self.is_dynamic(ret_type):
+            out_shapes = self.emit_shape_func(scope, func, new_args)
+            shape_expr = out_shapes[0]
+        else:
+            # constant output shape
+            shape = [int(dim) for dim in ret_type.shape]
+            shape_expr = expr.const(shape, dtype=self.compute_dtype)
+        return self.reshape_tensor(new_args[0], shape_expr, ret_type.shape)
+
def is_dynamic(self, ret_type):
- is_dynamic = ty.type_has_any(ret_type)
+ is_dynamic = ty.is_dynamic(ret_type)
# TODO(@jroesch): restore this code, more complex then it seems
# for arg in call.args:
# is_dynamic = is_dynamic or arg.checked_type.is_dynamic()
ret_type = call.checked_type
out_types = flatten_tuple_type(ret_type)
+ if is_reshape_only(call.op):
+ # Handle fused op that only contains reshape op
+ return self.emit_reshape_tensor(scope, call.op, new_args, ret_type)
+
if self.is_dynamic(ret_type):
# Handle dynamic case.
return self.dynamic_invoke(scope, call.op, ins, new_args, out_types, ret_type)
- else:
- # Handle static case.
- outs = []
- for i, out_ty in enumerate(out_types):
- out = self.make_static_allocation(scope, out_ty, i)
- outs.append(out)
-
- output = expr.Tuple(outs)
- invoke = self.invoke_tvm(call.op, ins, output)
- scope.let("", invoke)
- return to_tuple_type(ret_type, output.fields)
- else:
- return super().visit_call(call)
+
+ # Handle static case.
+ outs = []
+ for i, out_ty in enumerate(out_types):
+ out = self.make_static_allocation(scope, out_ty, i)
+ outs.append(out)
+
+ output = expr.Tuple(outs)
+ invoke = self.invoke_tvm(call.op, ins, output)
+ scope.let("", invoke)
+ return to_tuple_type(ret_type, output.fields)
+ return super().visit_call(call)
@transform.function_pass(opt_level=0)
Any = _ffi_api.Any
-def type_has_any(tensor_type):
- """Check whether type has any as a shape.
+def is_dynamic(tensor_type):
+ """Check whether type has any or symbolic variables as a shape.
tensor_type : Type
The type to be inspected
bool is_dyn{false};
void VisitType_(const TensorTypeNode* tt) {
for (auto dim : tt->shape) {
- if (dim.as<AnyNode>()) {
+ if (dim.as<tir::IntImmNode>() == nullptr) {
is_dyn = true;
break;
}
case Opcode::AllocClosure:
case Opcode::AllocStorage:
case Opcode::ShapeOf:
+ case Opcode::ReshapeTensor:
case Opcode::Move:
case Opcode::InvokeClosure:
last_register_ = instr.dst;
this->VisitExpr(args[0]);
Emit(Instruction::ShapeOf(last_register_, NewRegister()));
})
+ .Match("vm.reshape_tensor",
+ [this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
+ CHECK_EQ(args.size(), 2u);
+ this->VisitExpr(args[0]);
+ auto tensor_reg = last_register_;
+ this->VisitExpr(args[1]);
+ auto shape_reg = last_register_;
+ Emit(Instruction::ReshapeTensor(tensor_reg, shape_reg, NewRegister()));
+ })
.Match("memory.kill",
[](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
LOG(FATAL) << "memory.kill is not yet supported";
infer_dim = indexdiv(infer_dim, oshape[i]);
}
}
+ arith::Analyzer ana;
+ infer_dim = ana.Simplify(infer_dim);
oshape.Set(infer_idx, infer_dim);
}
namespace tvm {
namespace relay {
+// vm.shape_func
TVM_REGISTER_NODE_TYPE(ShapeFuncAttrs);
RELAY_REGISTER_OP("vm.shape_of")
return {topi::identity(inputs[0])};
});
+// vm.invoke_tvm_op
bool InvokeTVMOpRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 4u);
return {topi::identity(inputs[0])};
});
+// vm.reshape_tensor
+TVM_REGISTER_NODE_TYPE(ReshapeTensorAttrs);
+
+bool ReshapeTensorRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                      const TypeReporter& reporter) {
+  // types layout: [data, shape, output].
+  CHECK_EQ(types.size(), 3u);
+  auto reshape_attrs = attrs.as<ReshapeTensorAttrs>();
+  CHECK(reshape_attrs);
+  auto tt = types[0].as<TensorTypeNode>();
+  CHECK(tt) << "input must be tensor type";
+  // The output type is derived from the statically recorded newshape
+  // attribute, not from the runtime shape tensor (types[1]).
+  reporter->Assign(types[2], TensorType(reshape_attrs->newshape, tt->dtype));
+  return true;
+}
+
+RELAY_REGISTER_OP("vm.reshape_tensor")
+    .describe(R"code(Use VM reshape_tensor instruction to reshape the tensor.
+)code" TVM_ADD_FILELINE)
+    .set_num_inputs(2)
+    .add_argument("data", "Tensor", "The input tensor")
+    .add_argument("shape", "Tensor", "The output shape tensor")
+    .add_type_rel("ReshapeTensor", ReshapeTensorRel)
+    .set_support_level(10)
+    .set_attr<TOpPattern>("TOpPattern", kOpaque)
+    .set_attr<TOpIsStateful>("TOpIsStateful", false)
+    // Non-computational: no kernel is generated; the VM handles it as a
+    // dedicated ReshapeTensor instruction.
+    .set_attr<TNonComputational>("TNonComputational", true)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
+
+// FFI entry point backing the Python-side reshape_tensor helper.
+TVM_REGISTER_GLOBAL("relay.op.vm.reshape_tensor")
+    .set_body_typed([](Expr data, Expr shape, Array<PrimExpr> newshape) {
+      static const Op& op = Op::Get("vm.reshape_tensor");
+      auto attrs = make_object<ReshapeTensorAttrs>();
+      attrs->newshape = std::move(newshape);
+      return Call(op, {data, shape}, Attrs(attrs), {});
+    });
+
} // namespace relay
} // namespace tvm
fields.assign({instr.shape_of.tensor, instr.dst});
break;
}
+ case Opcode::ReshapeTensor: {
+ // Number of fields = 3
+ fields.assign({instr.reshape_tensor.tensor, instr.reshape_tensor.newshape, instr.dst});
+ break;
+ }
default:
LOG(FATAL) << "Invalid opcode" << static_cast<int>(instr.op);
break;
DCHECK_EQ(instr.fields.size(), 2U);
return Instruction::ShapeOf(instr.fields[0], instr.fields[1]);
}
+ case Opcode::ReshapeTensor: {
+ // Number of fields = 3
+ DCHECK_EQ(instr.fields.size(), 3U);
+ return Instruction::ReshapeTensor(instr.fields[0], instr.fields[1], instr.fields[2]);
+ }
default:
LOG(FATAL) << "Invalid opcode" << instr.opcode;
return Instruction();
case Opcode::ShapeOf:
this->shape_of.tensor = instr.shape_of.tensor;
return;
+ case Opcode::ReshapeTensor:
+ this->reshape_tensor.tensor = instr.reshape_tensor.tensor;
+ this->reshape_tensor.newshape = instr.reshape_tensor.newshape;
+ return;
default:
std::ostringstream out;
out << "Invalid instruction " << static_cast<int>(instr.op);
case Opcode::LoadConsti:
case Opcode::AllocStorage:
case Opcode::ShapeOf:
+ case Opcode::ReshapeTensor:
case Opcode::Fatal:
return;
case Opcode::AllocTensor:
Instruction Instruction::AllocTensor(RegName storage, RegName offset,
const std::vector<int64_t>& shape, DLDataType dtype,
- Index dst) {
+ RegName dst) {
Instruction instr;
instr.op = Opcode::AllocTensor;
instr.dst = dst;
}
Instruction Instruction::AllocTensorReg(RegName storage, RegName offset, RegName shape_register,
- DLDataType dtype, Index dst) {
+ DLDataType dtype, RegName dst) {
Instruction instr;
instr.op = Opcode::AllocTensorReg;
instr.dst = dst;
}
Instruction Instruction::AllocStorage(RegName size, Index alignment, DLDataType dtype_hint,
- Index dst) {
+ RegName dst) {
Instruction instr;
instr.op = Opcode::AllocStorage;
instr.dst = dst;
return instr;
}
-Instruction Instruction::ShapeOf(RegName tensor, Index dst) {
+Instruction Instruction::ShapeOf(RegName tensor, RegName dst) {
Instruction instr;
instr.op = Opcode::ShapeOf;
instr.dst = dst;
return instr;
}
+// Build a ReshapeTensor instruction: reshape the tensor in register
+// `tensor` to the shape held in register `newshape`, writing into `dst`.
+Instruction Instruction::ReshapeTensor(RegName tensor, RegName newshape, RegName dst) {
+  Instruction instr;
+  instr.op = Opcode::ReshapeTensor;
+  instr.dst = dst;
+  instr.reshape_tensor.tensor = tensor;
+  instr.reshape_tensor.newshape = newshape;
+  return instr;
+}
+
Instruction Instruction::AllocADT(Index tag, Index num_fields,
- const std::vector<RegName>& datatype_fields, Index dst) {
+ const std::vector<RegName>& datatype_fields, RegName dst) {
Instruction instr;
instr.op = Opcode::AllocADT;
instr.dst = dst;
}
Instruction Instruction::AllocClosure(Index func_index, Index free_vars,
- const std::vector<RegName>& free_var_register, Index dst) {
+ const std::vector<RegName>& free_var_register, RegName dst) {
Instruction instr;
instr.op = Opcode::AllocClosure;
instr.dst = dst;
os << "shape_of $" << instr.dst << " $" << instr.shape_of.tensor;
break;
}
+ case Opcode::ReshapeTensor: {
+ os << "reshape_tensor $" << instr.dst << " $" << instr.reshape_tensor.tensor << " $"
+ << instr.reshape_tensor.newshape;
+ break;
+ }
default:
LOG(FATAL) << "should never hit this case" << static_cast<int>(instr.op);
break;
goto main_loop;
}
}
+      case Opcode::ReshapeTensor: {
+        // The shape values are always read on the CPU; CopyTo brings the
+        // shape tensor over from the device if needed.
+        DLContext cpu_ctx;
+        cpu_ctx.device_type = kDLCPU;
+        cpu_ctx.device_id = 0;
+        auto tensor_obj = ReadRegister(instr.reshape_tensor.tensor);
+        NDArray tensor_arr = Downcast<NDArray>(tensor_obj);
+        // Read the shape from shape tensor
+        auto shape_obj = ReadRegister(instr.reshape_tensor.newshape);
+        NDArray shape_tensor = Downcast<NDArray>(CopyTo(shape_obj, cpu_ctx));
+        const DLTensor* dl_tensor = shape_tensor.operator->();
+        // Assumes a 1-D int64 shape tensor: dtype code 0 is the DLPack
+        // integer code, and only shape[0] is read as the rank.
+        CHECK_EQ(dl_tensor->dtype.code, 0u);
+        CHECK_EQ(dl_tensor->dtype.bits, 64);
+        int64_t* dims = reinterpret_cast<int64_t*>(dl_tensor->data);
+        int64_t ndim = shape_tensor->shape[0];
+        std::vector<int64_t> shape(dims, dims + ndim);
+        // Reshape the input tensor
+        // NOTE(review): CreateView presumably aliases the input storage
+        // (zero-copy) — confirm against NDArray::CreateView semantics.
+        auto out_tensor = tensor_arr.CreateView(shape, tensor_arr->dtype);
+        WriteRegister(instr.dst, out_tensor);
+        pc_++;
+        goto main_loop;
+      }
+ default:
+ LOG(FATAL) << "Unknown instruction opcode: " << int(instr.op);
}
}
}
expected_result:
The expected result of running the expression.
"""
+ # TODO(@zhiics, @icemelon9): Disable the gpu test for now until the heterogeneous support
+ # is ready
for target, ctx in ctx_list():
+ if "cuda" in target:
+ continue
vm = relay.create_executor('vm', ctx=ctx, target=target, mod=mod)
rts_result = vm.evaluate()(*args)
mod["main"] = relay.Function(relay.analysis.free_vars(ret), ret)
check_result(args, expected, mod=mod)
+def test_vm_reshape_tensor():
+    # NOTE(review): the local name `exec` shadows the Python builtin;
+    # consider renaming it (e.g. to `executable`).
+    # Static reshape: should compile to a reshape_tensor instruction.
+    x_np = np.random.uniform(size=(8, 16)).astype("float32")
+    x = relay.var("x", shape=(8, 16), dtype="float32")
+    y = relay.reshape(x, [-1, 4, 8])
+    mod = tvm.IRModule()
+    mod["main"] = relay.Function([x], y)
+    with tvm.transform.PassContext(opt_level=3):
+        exec = relay.vm.compile(mod, "llvm")
+    assert "reshape_tensor" in exec.bytecode
+    check_result([x_np], x_np.reshape([4, 4, 8]), mod)
+
+    # Fused reshape + reverse_reshape must collapse into a single
+    # reshape_tensor instruction.
+    x = relay.var("x", shape=(8, 16), dtype="float32")
+    y = relay.reshape(x, [16, -1])
+    y = relay.reverse_reshape(y, [-1, 4, 0])
+    mod = tvm.IRModule()
+    mod["main"] = relay.Function([x], y)
+    with tvm.transform.PassContext(opt_level=3):
+        exec = relay.vm.compile(mod, "llvm")
+    assert exec.bytecode.count("reshape_tensor") == 1
+    check_result([x_np], x_np.reshape([4, 4, 8]), mod)
+
+    # reshape with symbolic/any shape
+    for n in [tvm.tir.Any(), tvm.te.size_var('n')]:
+        x = relay.var("x", shape=(n, 16), dtype="float32")
+        y = relay.reshape(x, [-1, 4])
+        y = relay.reshape(y, [0, 2, -1])
+        mod = tvm.IRModule()
+        mod["main"] = relay.Function([x], y)
+        with tvm.transform.PassContext(opt_level=3):
+            exec = relay.vm.compile(mod, "llvm")
+        assert exec.bytecode.count("reshape_tensor") == 1
+        check_result([x_np], x_np.reshape([32, 2, 2]), mod)
+
+    # dyn.reshape
+    # The second reshape takes a runtime shape tensor, so two separate
+    # reshape_tensor instructions are expected.
+    x = relay.var("x", shape=(8, 16), dtype="float32")
+    y = relay.var("y", shape=(3,), dtype="int32")
+    z = relay.reshape(x, [-1, 4, 8])
+    z = relay.reshape(z, y)
+    mod = tvm.IRModule()
+    mod["main"] = relay.Function([x, y], z)
+    with tvm.transform.PassContext(opt_level=3):
+        exec = relay.vm.compile(mod, "llvm")
+    assert exec.bytecode.count("reshape_tensor") == 2
+    assert "reshape_tensor" in exec.bytecode
+    y_np = np.array([8, 2, 8]).astype("int32")
+    check_result([x_np, y_np], x_np.reshape([8, 2, 8]), mod)
+
+
if __name__ == "__main__":
pytest.main([__file__])