// or make a more powerful partial evaluator.
class ConstantFolder : public ExprMutator {
public:
- explicit ConstantFolder(FInterpreter executor, IRModule module)
- : executor_(executor),
- module_(module),
+ explicit ConstantFolder(IRModule module)
+ : module_(module),
shape_of_op_(Op::Get("shape_of")),
vm_shape_of_op_(Op::Get("vm.shape_of")),
invoke_tvm_op_(Op::Get("vm.invoke_tvm_op")),
}
private:
- // Internal interepreter.
- FInterpreter executor_;
// Internal constant checker
ConstantChecker checker_;
// Module
const Op& cast_op_;
const Op& ndarray_size_op_;
+ // Create an interpreter.
+ FInterpreter GetInterpreter(const IRModule& mod) {
+ using tvm::transform::PassContext;
+ DLContext ctx;
+ ctx.device_type = kDLCPU;
+ ctx.device_id = 0;
+ Target target = Target::Create("llvm");
+ // use a fresh build context
+ // in case we are already in a build context.
+ With<PassContext> fresh_build_ctx(PassContext::Create());
+
+ return CreateInterpreter(mod, ctx, target);
+ }
+
// Convert value to expression.
Expr ObjectToExpr(const ObjectRef& value) {
if (value->IsInstance<runtime::NDArray::ContainerType>()) {
mod = seq(mod);
auto entry_func = Downcast<Function>(mod->Lookup("main"));
expr = expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
- return ObjectToExpr(executor_(expr));
+
+ FInterpreter executor = GetInterpreter(mod);
+ return ObjectToExpr(executor(expr));
}
// Evaluate a call to the shape_of operator for tensors with constant
};
Expr FoldConstant(const Expr& expr, const IRModule& mod) {
- using tvm::transform::PassContext;
- DLContext ctx;
- ctx.device_type = kDLCPU;
- ctx.device_id = 0;
- Target target = Target::Create("llvm");
- // use a fresh build context
- // in case we are already in a build context.
- With<PassContext> fresh_build_ctx(PassContext::Create());
-
- return ConstantFolder(CreateInterpreter(mod, ctx, target), mod).Mutate(expr);
+ return ConstantFolder(mod).Mutate(expr);
}
namespace transform {