[Relay][VM] Add ReshapeTensor instruction in the VM to replace the reshape op (#6089)
author Haichen Shen <shenhaichen@gmail.com>
Tue, 21 Jul 2020 17:10:16 +0000 (10:10 -0700)
committer GitHub <noreply@github.com>
Tue, 21 Jul 2020 17:10:16 +0000 (10:10 -0700)
* [VM] Add reshape tensor instruction

* update

* lint

* fix

* fix

15 files changed:
include/tvm/relay/attrs/vm.h
include/tvm/runtime/vm.h
python/tvm/relay/backend/compile_engine.py
python/tvm/relay/backend/vm.py
python/tvm/relay/build_module.py
python/tvm/relay/op/vm/vm.py
python/tvm/relay/transform/memory_alloc.py
python/tvm/relay/ty.py
src/relay/analysis/util.cc
src/relay/backend/vm/compiler.cc
src/relay/op/tensor/transform.cc
src/relay/op/vm/vm.cc
src/runtime/vm/executable.cc
src/runtime/vm/vm.cc
tests/python/relay/test_vm.py

diff --git a/include/tvm/relay/attrs/vm.h b/include/tvm/relay/attrs/vm.h
index 9144f47..7eb1008 100644
@@ -42,6 +42,17 @@ struct ShapeFuncAttrs : public tvm::AttrsNode<ShapeFuncAttrs> {
   }
 };
 
+/*!
+ * \brief Attributes for VM reshape_tensor operator.
+ */
+struct ReshapeTensorAttrs : public tvm::AttrsNode<ReshapeTensorAttrs> {
+  Array<PrimExpr> newshape;
+
+  TVM_DECLARE_ATTRS(ReshapeTensorAttrs, "relay.attrs.ReshapeTensorAttrs") {
+    TVM_ATTR_FIELD(newshape).describe("The new shape of the output tensor");
+  }
+};
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_ATTRS_VM_H_
diff --git a/include/tvm/runtime/vm.h b/include/tvm/runtime/vm.h
index 0cce533..cb98715 100644
@@ -115,6 +115,7 @@ enum class Opcode {
   Fatal = 15U,
   AllocStorage = 16U,
   ShapeOf = 17U,
+  ReshapeTensor = 18U,
 };
 
 /*! \brief A single virtual machine instruction.
@@ -249,6 +250,10 @@ struct Instruction {
     struct /* ShapeOf Operands */ {
       RegName tensor;
     } shape_of;
+    struct /* ReshapeTensor Operands */ {
+      RegName tensor;
+      RegName newshape;
+    } reshape_tensor;
   };
 
   /*!
@@ -401,6 +406,15 @@ struct Instruction {
    */
   static Instruction ShapeOf(RegName tensor, RegName dst);
 
+  /*!
+   * \brief Reshape the tensor given the new shape.
+   * \param tensor The input tensor.
+   * \param newshape The shape tensor.
+   * \param dst The destination to store the output tensor with new shape.
+   * \return The reshape tensor instruction.
+   */
+  static Instruction ReshapeTensor(RegName tensor, RegName newshape, RegName dst);
+
   Instruction();
   Instruction(const Instruction& instr);
   Instruction& operator=(const Instruction& instr);
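
The new opcode takes two register operands (the input tensor and the tensor holding the target shape) plus a destination register. As an illustrative check that is not part of this patch, the compiled bytecode can be inspected for the instruction; the API usage mirrors the unit test at the end of this change, and the names and shapes are made up:

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(8, 16), dtype="float32")
    mod = tvm.IRModule.from_expr(relay.Function([x], relay.reshape(x, [-1, 4, 8])))
    with tvm.transform.PassContext(opt_level=3):
        exe = relay.vm.compile(mod, "llvm")
    # The fused reshape compiles to a ReshapeTensor instruction rather than a kernel call.
    print("reshape_tensor" in exe.bytecode)
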
diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py
index 8e6698e..25c75b1 100644
@@ -246,9 +246,9 @@ def lower_call(call, inputs, target):
                 new_fields.append(field)
         ret_type = _ty.TupleType(new_fields)
 
-    is_dyn = _ty.type_has_any(call.checked_type)
+    is_dyn = _ty.is_dynamic(call.checked_type)
     for arg in call.args:
-        is_dyn = is_dyn or _ty.type_has_any(arg.checked_type)
+        is_dyn = is_dyn or _ty.is_dynamic(arg.checked_type)
 
     # check if in the AutoTVM tracing mode, and disable if op is not in wanted list
     env = autotvm.task.TaskExtractEnv.current
diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py
index 75a11b3..16d4724 100644
@@ -27,7 +27,7 @@ import tvm.runtime.ndarray as _nd
 import tvm.runtime.vm as vm_rt
 from tvm import autotvm
 from tvm.relay import expr as _expr
-from tvm.relay.ty import type_has_any
+from tvm.relay.ty import is_dynamic
 from tvm.relay.backend.interpreter import Executor
 from . import _vm
 
@@ -257,7 +257,7 @@ class VMExecutor(Executor):
         def _vm_wrapper(*args, **kwargs):
             args = self._convert_args(main, args, kwargs)
             ret_type = self.mod["main"].checked_type.ret_type
-            if type_has_any(ret_type) and "llvm" not in str(self.target) and "arm" not in str(
+            if is_dynamic(ret_type) and "llvm" not in str(self.target) and "arm" not in str(
                     self.target):
                 raise ValueError(
                     "Virtual Machine only supports dynamic graphs on CPU, got output type",
diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py
index 4ffabf4..2f285ef 100644
@@ -359,7 +359,7 @@ class GraphExecutor(_interpreter.Executor):
         if expr:
             self.mod["main"] = expr
         ret_type = self.mod["main"].checked_type.ret_type
-        if _ty.type_has_any(ret_type):
+        if _ty.is_dynamic(ret_type):
             raise ValueError("Graph Runtime only supports static graphs, got output type",
                              ret_type)
         num_outputs = len(ret_type.fields) if isinstance(ret_type, _ty.TupleType) else 1
diff --git a/python/tvm/relay/op/vm/vm.py b/python/tvm/relay/op/vm/vm.py
index 761188a..0fb7ace 100644
@@ -81,3 +81,20 @@ def shape_func(func, inputs, outputs, is_inputs):
         The shape function expression.
     """
     return _ffi_api.shape_func(func, inputs, outputs, is_inputs)
+
+
+def reshape_tensor(data, shape, newshape):
+    """Invoke the VM ReshapeTensor instruction.
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input data.
+
+    shape : tvm.relay.Expr
+        The newshape tensor.
+
+    newshape : List[tvm.ir.PrimExpr]
+        The new shape.
+    """
+    return _ffi_api.reshape_tensor(data, shape, newshape)
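
This helper is what the memory-planning pass calls to build a vm.reshape_tensor call node: the statically known newshape goes into the op attributes, while the shape tensor argument supplies the concrete shape at run time. A small illustrative sketch (the variable names and shapes are invented):

    import numpy as np
    import tvm
    from tvm import relay
    from tvm.relay import op

    data = relay.var("data", shape=(8, 16), dtype="float32")
    shape = relay.const(np.array([4, 4, 8], dtype="int64"))
    call = op.vm.reshape_tensor(data, shape, newshape=[4, 4, 8])
    print(call.op.name)         # vm.reshape_tensor
    print(call.attrs.newshape)  # the statically known target shape
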
diff --git a/python/tvm/relay/transform/memory_alloc.py b/python/tvm/relay/transform/memory_alloc.py
index 805905c..ae7db33 100644
@@ -19,7 +19,7 @@
 A pass for manifesting explicit memory allocations.
 """
 import numpy as np
-from ..expr_functor import ExprMutator
+from ..expr_functor import ExprVisitor, ExprMutator
 from ..scope_builder import ScopeBuilder
 from . import transform
 from .. import op
@@ -38,6 +38,31 @@ def is_primitive(call):
     return hasattr(call, 'op') and hasattr(call.op, 'attrs') and \
            hasattr(call.op.attrs, 'Primitive') and int(call.op.attrs.Primitive) == 1
 
+
+class CheckReshapeOnly(ExprVisitor):
+    """A pass to check if the fused op contains only reshape ops."""
+    def __init__(self):
+        super().__init__()
+        self._reshape_ops = [op.get("reshape"), op.get("contrib_reverse_reshape"),
+                             op.get("dyn.reshape")]
+        self.reshape_only = True
+
+    def visit_call(self, call):
+        if not self.reshape_only:
+            return
+        if call.op not in self._reshape_ops:
+            self.reshape_only = False
+        for arg in call.args:
+            self.visit(arg)
+
+
+def is_reshape_only(func):
+    """Check if the primitive function contains only reshape ops."""
+    check = CheckReshapeOnly()
+    check.visit(func)
+    return check.reshape_only
+
+
 class ManifestAllocPass(ExprMutator):
     """A pass for explicitly manifesting all memory allocations in Relay."""
 
@@ -45,6 +70,7 @@ class ManifestAllocPass(ExprMutator):
         self.invoke_tvm = op.vm.invoke_tvm_op
         self.shape_func = op.vm.shape_func
         self.shape_of = op.vm.shape_of
+        self.reshape_tensor = op.vm.reshape_tensor
         self.scopes = [ScopeBuilder()]
         self.target_host = target_host
         self.default_context = cpu(0)
@@ -121,8 +147,8 @@ class ManifestAllocPass(ExprMutator):
 
         return scope.get()
 
-    def dynamic_invoke(self, scope, func, ins, new_args, out_types, ret_type):
-        """Generate the code for invoking a TVM op with a dynamic shape."""
+    def emit_shape_func(self, scope, func, new_args):
+        """Insert the shape function given a primitive function."""
         shape_func_ins = []
         engine = compile_engine.get()
         cfunc = engine.lower_shape_func(func, self.target_host)
@@ -165,9 +191,14 @@ class ManifestAllocPass(ExprMutator):
             expr.Tuple(out_shapes), is_inputs)
 
         scope.let("shape_func", shape_call)
+        return out_shapes
+
+    def dynamic_invoke(self, scope, func, ins, new_args, out_types, ret_type):
+        """Generate the code for invoking a TVM op with a dynamic shape."""
+        out_shapes = self.emit_shape_func(scope, func, new_args)
 
         storages = []
-        for out_shape, out_type in zip(out_shapes, out_types):
+        for i, (out_shape, out_type) in enumerate(zip(out_shapes, out_types)):
             size = self.compute_storage_in_relay(
                 out_shape, out_type.dtype)
             alignment = self.compute_alignment(out_type.dtype)
@@ -191,8 +222,18 @@ class ManifestAllocPass(ExprMutator):
         scope.let("", invoke)
         return to_tuple_type(ret_type, tuple_outs.fields)
 
+    def emit_reshape_tensor(self, scope, func, new_args, ret_type):
+        if self.is_dynamic(ret_type):
+            out_shapes = self.emit_shape_func(scope, func, new_args)
+            shape_expr = out_shapes[0]
+        else:
+            # constant output shape
+            shape = [int(dim) for dim in ret_type.shape]
+            shape_expr = expr.const(shape, dtype=self.compute_dtype)
+        return self.reshape_tensor(new_args[0], shape_expr, ret_type.shape)
+
     def is_dynamic(self, ret_type):
-        is_dynamic = ty.type_has_any(ret_type)
+        is_dynamic = ty.is_dynamic(ret_type)
         # TODO(@jroesch): restore this code, more complex than it seems
         # for arg in call.args:
         #     is_dynamic = is_dynamic or arg.checked_type.is_dynamic()
@@ -208,22 +249,25 @@ class ManifestAllocPass(ExprMutator):
             ret_type = call.checked_type
             out_types = flatten_tuple_type(ret_type)
 
+            if is_reshape_only(call.op):
+                # Handle fused op that only contains reshape op
+                return self.emit_reshape_tensor(scope, call.op, new_args, ret_type)
+
             if self.is_dynamic(ret_type):
                 # Handle dynamic case.
                 return self.dynamic_invoke(scope, call.op, ins, new_args, out_types, ret_type)
-            else:
-                # Handle static case.
-                outs = []
-                for i, out_ty in enumerate(out_types):
-                    out = self.make_static_allocation(scope, out_ty, i)
-                    outs.append(out)
-
-                output = expr.Tuple(outs)
-                invoke = self.invoke_tvm(call.op, ins, output)
-                scope.let("", invoke)
-                return to_tuple_type(ret_type, output.fields)
-        else:
-            return super().visit_call(call)
+
+            # Handle static case.
+            outs = []
+            for i, out_ty in enumerate(out_types):
+                out = self.make_static_allocation(scope, out_ty, i)
+                outs.append(out)
+
+            output = expr.Tuple(outs)
+            invoke = self.invoke_tvm(call.op, ins, output)
+            scope.let("", invoke)
+            return to_tuple_type(ret_type, output.fields)
+        return super().visit_call(call)
 
 
 @transform.function_pass(opt_level=0)
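
With the reshape-only check in place, a fused function made up entirely of reshape ops is lowered to a single ReshapeTensor instruction instead of a shape-function/allocation/kernel-invoke sequence. A quick illustrative check along the lines of the new unit test (shapes are arbitrary):

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(8, 16), dtype="float32")
    y = relay.reshape(relay.reshape(x, [16, -1]), [-1, 4, 4])
    mod = tvm.IRModule.from_expr(relay.Function([x], y))
    with tvm.transform.PassContext(opt_level=3):
        exe = relay.vm.compile(mod, "llvm")
    # The whole fused reshape chain should collapse into one instruction.
    print(exe.bytecode.count("reshape_tensor"))  # expected: 1
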
diff --git a/python/tvm/relay/ty.py b/python/tvm/relay/ty.py
index 19cc10a..84bd1ee 100644
@@ -25,8 +25,8 @@ from . import _ffi_api
 
 Any = _ffi_api.Any
 
-def type_has_any(tensor_type):
-    """Check whether type has any as a shape.
+def is_dynamic(tensor_type):
+    """Check whether the type has Any or symbolic variables in its shape.
 
     tensor_type : Type
         The type to be inspected
diff --git a/src/relay/analysis/util.cc b/src/relay/analysis/util.cc
index c8dbb49..b1c5124 100644
@@ -424,7 +424,7 @@ struct IsDynamicVisitor : public TypeVisitor {
   bool is_dyn{false};
   void VisitType_(const TensorTypeNode* tt) {
     for (auto dim : tt->shape) {
-      if (dim.as<AnyNode>()) {
+      if (dim.as<tir::IntImmNode>() == nullptr) {
         is_dyn = true;
         break;
       }
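
The widened predicate treats any dimension that is not a constant integer as dynamic, so symbolic shape variables now count as well as Any. A small sketch using the renamed Python helper (the shapes are illustrative):

    import tvm
    from tvm import relay
    from tvm.relay.ty import is_dynamic

    print(is_dynamic(relay.TensorType((8, 16), "float32")))                     # False
    print(is_dynamic(relay.TensorType((tvm.tir.Any(), 16), "float32")))         # True
    print(is_dynamic(relay.TensorType((tvm.te.size_var("n"), 16), "float32")))  # True
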
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index 585b803..ab11c6c 100644
@@ -284,6 +284,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
       case Opcode::AllocClosure:
       case Opcode::AllocStorage:
       case Opcode::ShapeOf:
+      case Opcode::ReshapeTensor:
       case Opcode::Move:
       case Opcode::InvokeClosure:
         last_register_ = instr.dst;
@@ -601,6 +602,15 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
                    this->VisitExpr(args[0]);
                    Emit(Instruction::ShapeOf(last_register_, NewRegister()));
                  })
+          .Match("vm.reshape_tensor",
+                 [this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
+                   CHECK_EQ(args.size(), 2u);
+                   this->VisitExpr(args[0]);
+                   auto tensor_reg = last_register_;
+                   this->VisitExpr(args[1]);
+                   auto shape_reg = last_register_;
+                   Emit(Instruction::ReshapeTensor(tensor_reg, shape_reg, NewRegister()));
+                 })
           .Match("memory.kill",
                  [](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
                    LOG(FATAL) << "memory.kill is not yet supported";
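
The compiler visits the two arguments of a vm.reshape_tensor call to obtain the tensor and shape registers, then emits the instruction into a fresh destination register. An illustrative way to look at the emitted form (register numbers will vary from run to run):

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(8, 16), dtype="float32")
    mod = tvm.IRModule.from_expr(relay.Function([x], relay.reshape(x, [-1, 4, 8])))
    with tvm.transform.PassContext(opt_level=3):
        exe = relay.vm.compile(mod, "llvm")
    for line in exe.bytecode.split("\n"):
        if "reshape_tensor" in line:
            print(line)  # e.g. "reshape_tensor $2 $0 $1"
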
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 1b07253..7ebca66 100644
@@ -576,6 +576,8 @@ bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
         infer_dim = indexdiv(infer_dim, oshape[i]);
       }
     }
+    arith::Analyzer ana;
+    infer_dim = ana.Simplify(infer_dim);
     oshape.Set(infer_idx, infer_dim);
   }
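
Simplifying the inferred dimension mainly matters when the input shape is symbolic, where the -1 entry would otherwise remain an unsimplified integer division. A sketch of the effect (the exact printed form of the resulting shape may differ):

    import tvm
    from tvm import relay

    n = tvm.te.size_var("n")
    x = relay.var("x", shape=(n, 16), dtype="float32")
    f = relay.Function([x], relay.reshape(x, [-1, 4]))
    mod = relay.transform.InferType()(tvm.IRModule.from_expr(f))
    print(mod["main"].body.checked_type)  # roughly Tensor[(n*4, 4), float32]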
 
diff --git a/src/relay/op/vm/vm.cc b/src/relay/op/vm/vm.cc
index ffe276e..6e611d6 100644
@@ -37,6 +37,7 @@
 namespace tvm {
 namespace relay {
 
+// vm.shape_func
 TVM_REGISTER_NODE_TYPE(ShapeFuncAttrs);
 
 RELAY_REGISTER_OP("vm.shape_of")
@@ -133,6 +134,7 @@ RELAY_REGISTER_OP("vm.shape_func")
                              return {topi::identity(inputs[0])};
                            });
 
+// vm.invoke_tvm_op
 bool InvokeTVMOpRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                     const TypeReporter& reporter) {
   CHECK_EQ(types.size(), 4u);
@@ -181,5 +183,40 @@ RELAY_REGISTER_OP("vm.invoke_tvm_op")
                              return {topi::identity(inputs[0])};
                            });
 
+// vm.reshape_tensor
+TVM_REGISTER_NODE_TYPE(ReshapeTensorAttrs);
+
+bool ReshapeTensorRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                      const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 3u);
+  auto reshape_attrs = attrs.as<ReshapeTensorAttrs>();
+  CHECK(reshape_attrs);
+  auto tt = types[0].as<TensorTypeNode>();
+  CHECK(tt) << "input must be tensor type";
+  reporter->Assign(types[2], TensorType(reshape_attrs->newshape, tt->dtype));
+  return true;
+}
+
+RELAY_REGISTER_OP("vm.reshape_tensor")
+    .describe(R"code(Use VM reshape_tensor instruction to reshape the tensor.
+)code" TVM_ADD_FILELINE)
+    .set_num_inputs(2)
+    .add_argument("data", "Tensor", "The input tensor")
+    .add_argument("shape", "Tensor", "The output shape tensor")
+    .add_type_rel("ReshapeTensor", ReshapeTensorRel)
+    .set_support_level(10)
+    .set_attr<TOpPattern>("TOpPattern", kOpaque)
+    .set_attr<TOpIsStateful>("TOpIsStateful", false)
+    .set_attr<TNonComputational>("TNonComputational", true)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);
+
+TVM_REGISTER_GLOBAL("relay.op.vm.reshape_tensor")
+    .set_body_typed([](Expr data, Expr shape, Array<PrimExpr> newshape) {
+      static const Op& op = Op::Get("vm.reshape_tensor");
+      auto attrs = make_object<ReshapeTensorAttrs>();
+      attrs->newshape = std::move(newshape);
+      return Call(op, {data, shape}, Attrs(attrs), {});
+    });
+
 }  // namespace relay
 }  // namespace tvm
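
The type relation derives the output type from the newshape attribute; the shape tensor operand is only consulted at run time. A hypothetical usage sketch (variables invented for illustration):

    import tvm
    from tvm import relay
    from tvm.relay import op

    data = relay.var("data", shape=(8, 16), dtype="float32")
    shape = relay.var("shape", shape=(3,), dtype="int64")
    call = op.vm.reshape_tensor(data, shape, newshape=[4, 4, 8])
    f = relay.Function([data, shape], call)
    mod = relay.transform.InferType()(tvm.IRModule.from_expr(f))
    print(mod["main"].body.checked_type)  # Tensor[(4, 4, 8), float32]
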
diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc
index f520404..4944778 100644
@@ -422,6 +422,11 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) {
       fields.assign({instr.shape_of.tensor, instr.dst});
       break;
     }
+    case Opcode::ReshapeTensor: {
+      // Number of fields = 3
+      fields.assign({instr.reshape_tensor.tensor, instr.reshape_tensor.newshape, instr.dst});
+      break;
+    }
     default:
       LOG(FATAL) << "Invalid opcode" << static_cast<int>(instr.op);
       break;
@@ -693,6 +698,11 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) {
       DCHECK_EQ(instr.fields.size(), 2U);
       return Instruction::ShapeOf(instr.fields[0], instr.fields[1]);
     }
+    case Opcode::ReshapeTensor: {
+      // Number of fields = 3
+      DCHECK_EQ(instr.fields.size(), 3U);
+      return Instruction::ReshapeTensor(instr.fields[0], instr.fields[1], instr.fields[2]);
+    }
     default:
       LOG(FATAL) << "Invalid opcode" << instr.opcode;
       return Instruction();
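
Since both SerializeInstruction and DeserializeInstruction now handle the opcode, an executable containing it should survive a save/load round trip. A sketch assuming the existing Executable.save / Executable.load_exec API:

    import tvm
    import tvm.runtime.vm as rt_vm
    from tvm import relay

    x = relay.var("x", shape=(8, 16), dtype="float32")
    mod = tvm.IRModule.from_expr(relay.Function([x], relay.reshape(x, [-1, 4, 8])))
    with tvm.transform.PassContext(opt_level=3):
        exe = relay.vm.compile(mod, "llvm")
    code, lib = exe.save()
    loaded = rt_vm.Executable.load_exec(code, lib)
    print("reshape_tensor" in loaded.bytecode)  # True if the opcode round-trips
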
diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc
index 6b10a89..24fc110 100644
@@ -148,6 +148,10 @@ Instruction::Instruction(const Instruction& instr) {
     case Opcode::ShapeOf:
       this->shape_of.tensor = instr.shape_of.tensor;
       return;
+    case Opcode::ReshapeTensor:
+      this->reshape_tensor.tensor = instr.reshape_tensor.tensor;
+      this->reshape_tensor.newshape = instr.reshape_tensor.newshape;
+      return;
     default:
       std::ostringstream out;
       out << "Invalid instruction " << static_cast<int>(instr.op);
@@ -265,6 +269,7 @@ Instruction::~Instruction() {
     case Opcode::LoadConsti:
     case Opcode::AllocStorage:
     case Opcode::ShapeOf:
+    case Opcode::ReshapeTensor:
     case Opcode::Fatal:
       return;
     case Opcode::AllocTensor:
@@ -320,7 +325,7 @@ Instruction Instruction::InvokePacked(Index packed_index, Index arity, Index out
 
 Instruction Instruction::AllocTensor(RegName storage, RegName offset,
                                      const std::vector<int64_t>& shape, DLDataType dtype,
-                                     Index dst) {
+                                     RegName dst) {
   Instruction instr;
   instr.op = Opcode::AllocTensor;
   instr.dst = dst;
@@ -336,7 +341,7 @@ Instruction Instruction::AllocTensor(RegName storage, RegName offset,
 }
 
 Instruction Instruction::AllocTensorReg(RegName storage, RegName offset, RegName shape_register,
-                                        DLDataType dtype, Index dst) {
+                                        DLDataType dtype, RegName dst) {
   Instruction instr;
   instr.op = Opcode::AllocTensorReg;
   instr.dst = dst;
@@ -348,7 +353,7 @@ Instruction Instruction::AllocTensorReg(RegName storage, RegName offset, RegName
 }
 
 Instruction Instruction::AllocStorage(RegName size, Index alignment, DLDataType dtype_hint,
-                                      Index dst) {
+                                      RegName dst) {
   Instruction instr;
   instr.op = Opcode::AllocStorage;
   instr.dst = dst;
@@ -358,7 +363,7 @@ Instruction Instruction::AllocStorage(RegName size, Index alignment, DLDataType
   return instr;
 }
 
-Instruction Instruction::ShapeOf(RegName tensor, Index dst) {
+Instruction Instruction::ShapeOf(RegName tensor, RegName dst) {
   Instruction instr;
   instr.op = Opcode::ShapeOf;
   instr.dst = dst;
@@ -366,8 +371,17 @@ Instruction Instruction::ShapeOf(RegName tensor, Index dst) {
   return instr;
 }
 
+Instruction Instruction::ReshapeTensor(RegName tensor, RegName newshape, RegName dst) {
+  Instruction instr;
+  instr.op = Opcode::ReshapeTensor;
+  instr.dst = dst;
+  instr.reshape_tensor.tensor = tensor;
+  instr.reshape_tensor.newshape = newshape;
+  return instr;
+}
+
 Instruction Instruction::AllocADT(Index tag, Index num_fields,
-                                  const std::vector<RegName>& datatype_fields, Index dst) {
+                                  const std::vector<RegName>& datatype_fields, RegName dst) {
   Instruction instr;
   instr.op = Opcode::AllocADT;
   instr.dst = dst;
@@ -381,7 +395,7 @@ Instruction Instruction::AllocADT(Index tag, Index num_fields,
 }
 
 Instruction Instruction::AllocClosure(Index func_index, Index free_vars,
-                                      const std::vector<RegName>& free_var_register, Index dst) {
+                                      const std::vector<RegName>& free_var_register, RegName dst) {
   Instruction instr;
   instr.op = Opcode::AllocClosure;
   instr.dst = dst;
@@ -604,6 +618,11 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) {
       os << "shape_of $" << instr.dst << " $" << instr.shape_of.tensor;
       break;
     }
+    case Opcode::ReshapeTensor: {
+      os << "reshape_tensor $" << instr.dst << " $" << instr.reshape_tensor.tensor << " $"
+         << instr.reshape_tensor.newshape;
+      break;
+    }
     default:
       LOG(FATAL) << "should never hit this case" << static_cast<int>(instr.op);
       break;
@@ -1103,6 +1122,29 @@ void VirtualMachine::RunLoop() {
           goto main_loop;
         }
       }
+      case Opcode::ReshapeTensor: {
+        DLContext cpu_ctx;
+        cpu_ctx.device_type = kDLCPU;
+        cpu_ctx.device_id = 0;
+        auto tensor_obj = ReadRegister(instr.reshape_tensor.tensor);
+        NDArray tensor_arr = Downcast<NDArray>(tensor_obj);
+        // Read the shape from shape tensor
+        auto shape_obj = ReadRegister(instr.reshape_tensor.newshape);
+        NDArray shape_tensor = Downcast<NDArray>(CopyTo(shape_obj, cpu_ctx));
+        const DLTensor* dl_tensor = shape_tensor.operator->();
+        CHECK_EQ(dl_tensor->dtype.code, 0u);
+        CHECK_EQ(dl_tensor->dtype.bits, 64);
+        int64_t* dims = reinterpret_cast<int64_t*>(dl_tensor->data);
+        int64_t ndim = shape_tensor->shape[0];
+        std::vector<int64_t> shape(dims, dims + ndim);
+        // Reshape the input tensor
+        auto out_tensor = tensor_arr.CreateView(shape, tensor_arr->dtype);
+        WriteRegister(instr.dst, out_tensor);
+        pc_++;
+        goto main_loop;
+      }
+      default:
+        LOG(FATAL) << "Unknown instruction opcode: " << int(instr.op);
     }
   }
 }
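
At run time the handler copies the shape operand to the CPU, requires it to be an int64 tensor (the shape-function output), and returns a view over the input buffer via CreateView, so no data is copied. An end-to-end sketch assuming an LLVM-enabled build (names and shapes are illustrative):

    import numpy as np
    import tvm
    from tvm import relay

    x = relay.var("x", shape=(8, 16), dtype="float32")
    y = relay.var("y", shape=(3,), dtype="int32")
    mod = tvm.IRModule.from_expr(relay.Function([x, y], relay.reshape(x, y)))
    vm = relay.create_executor("vm", ctx=tvm.cpu(), target="llvm", mod=mod)
    x_np = np.random.uniform(size=(8, 16)).astype("float32")
    res = vm.evaluate()(x_np, np.array([8, 2, 8], dtype="int32"))
    print(res.asnumpy().shape)  # (8, 2, 8)
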
diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py
index f2b15ec..91214cb 100644
@@ -39,7 +39,11 @@ def check_result(args, expected_result, mod=None):
     expected_result:
         The expected result of running the expression.
     """
+    # TODO(@zhiics, @icemelon9): Disable the gpu test for now until the heterogeneous support
+    #   is ready
     for target, ctx in ctx_list():
+        if "cuda" in target:
+            continue
         vm = relay.create_executor('vm', ctx=ctx, target=target, mod=mod)
 
         rts_result = vm.evaluate()(*args)
@@ -622,5 +626,52 @@ def test_loop_free_var():
         mod["main"] = relay.Function(relay.analysis.free_vars(ret), ret)
         check_result(args, expected, mod=mod)
 
+def test_vm_reshape_tensor():
+    x_np = np.random.uniform(size=(8, 16)).astype("float32")
+    x = relay.var("x", shape=(8, 16), dtype="float32")
+    y = relay.reshape(x, [-1, 4, 8])
+    mod = tvm.IRModule()
+    mod["main"] = relay.Function([x], y)
+    with tvm.transform.PassContext(opt_level=3):
+        exec = relay.vm.compile(mod, "llvm")
+    assert "reshape_tensor" in exec.bytecode
+    check_result([x_np], x_np.reshape([4, 4, 8]), mod)
+
+    x = relay.var("x", shape=(8, 16), dtype="float32")
+    y = relay.reshape(x, [16, -1])
+    y = relay.reverse_reshape(y, [-1, 4, 0])
+    mod = tvm.IRModule()
+    mod["main"] = relay.Function([x], y)
+    with tvm.transform.PassContext(opt_level=3):
+        exec = relay.vm.compile(mod, "llvm")
+    assert exec.bytecode.count("reshape_tensor") == 1
+    check_result([x_np], x_np.reshape([4, 4, 8]), mod)
+
+    # reshape with symbolic/any shape
+    for n in [tvm.tir.Any(), tvm.te.size_var('n')]:
+        x = relay.var("x", shape=(n, 16), dtype="float32")
+        y = relay.reshape(x, [-1, 4])
+        y = relay.reshape(y, [0, 2, -1])
+        mod = tvm.IRModule()
+        mod["main"] = relay.Function([x], y)
+        with tvm.transform.PassContext(opt_level=3):
+            exec = relay.vm.compile(mod, "llvm")
+        assert exec.bytecode.count("reshape_tensor") == 1
+        check_result([x_np], x_np.reshape([32, 2, 2]), mod)
+
+    # dyn.reshape
+    x = relay.var("x", shape=(8, 16), dtype="float32")
+    y = relay.var("y", shape=(3,), dtype="int32")
+    z = relay.reshape(x, [-1, 4, 8])
+    z = relay.reshape(z, y)
+    mod = tvm.IRModule()
+    mod["main"] = relay.Function([x, y], z)
+    with tvm.transform.PassContext(opt_level=3):
+        exec = relay.vm.compile(mod, "llvm")
+    assert exec.bytecode.count("reshape_tensor") == 2
+    assert "reshape_tensor" in exec.bytecode
+    y_np = np.array([8, 2, 8]).astype("int32")
+    check_result([x_np, y_np], x_np.reshape([8, 2, 8]), mod)
+
 if __name__ == "__main__":
     pytest.main([__file__])