[FoldConstant] Create Interpreter for each constant subgraph (#6195)
authorAnimesh Jain <anijain@umich.edu>
Mon, 3 Aug 2020 16:36:41 +0000 (09:36 -0700)
committerGitHub <noreply@github.com>
Mon, 3 Aug 2020 16:36:41 +0000 (09:36 -0700)
src/relay/transforms/fold_constant.cc

index b077a8a..0ecbfea 100644 (file)
@@ -77,9 +77,8 @@ TVM_REGISTER_GLOBAL("relay.analysis.check_constant").set_body_typed(ConstantChec
 // or make a more powerful partial evaluator.
 class ConstantFolder : public ExprMutator {
  public:
-  explicit ConstantFolder(FInterpreter executor, IRModule module)
-      : executor_(executor),
-        module_(module),
+  explicit ConstantFolder(IRModule module)
+      : module_(module),
         shape_of_op_(Op::Get("shape_of")),
         vm_shape_of_op_(Op::Get("vm.shape_of")),
         invoke_tvm_op_(Op::Get("vm.invoke_tvm_op")),
@@ -163,8 +162,6 @@ class ConstantFolder : public ExprMutator {
   }
 
  private:
-  // Internal interepreter.
-  FInterpreter executor_;
   // Internal constant checker
   ConstantChecker checker_;
   // Module
@@ -180,6 +177,20 @@ class ConstantFolder : public ExprMutator {
   const Op& cast_op_;
   const Op& ndarray_size_op_;
 
+  // Create an interpreter.
+  FInterpreter GetInterpreter(const IRModule& mod) {
+    using tvm::transform::PassContext;
+    DLContext ctx;
+    ctx.device_type = kDLCPU;
+    ctx.device_id = 0;
+    Target target = Target::Create("llvm");
+    // use a fresh build context
+    // in case we are already in a build context.
+    With<PassContext> fresh_build_ctx(PassContext::Create());
+
+    return CreateInterpreter(mod, ctx, target);
+  }
+
   // Convert value to expression.
   Expr ObjectToExpr(const ObjectRef& value) {
     if (value->IsInstance<runtime::NDArray::ContainerType>()) {
@@ -218,7 +229,9 @@ class ConstantFolder : public ExprMutator {
     mod = seq(mod);
     auto entry_func = Downcast<Function>(mod->Lookup("main"));
     expr = expr.as<FunctionNode>() == nullptr ? entry_func->body : entry_func;
-    return ObjectToExpr(executor_(expr));
+
+    FInterpreter executor = GetInterpreter(mod);
+    return ObjectToExpr(executor(expr));
   }
 
   // Evaluate a call to the shape_of operator for tensors with constant
@@ -331,16 +344,7 @@ class ConstantFolder : public ExprMutator {
 };
 
 Expr FoldConstant(const Expr& expr, const IRModule& mod) {
-  using tvm::transform::PassContext;
-  DLContext ctx;
-  ctx.device_type = kDLCPU;
-  ctx.device_id = 0;
-  Target target = Target::Create("llvm");
-  // use a fresh build context
-  // in case we are already in a build context.
-  With<PassContext> fresh_build_ctx(PassContext::Create());
-
-  return ConstantFolder(CreateInterpreter(mod, ctx, target), mod).Mutate(expr);
+  return ConstantFolder(mod).Mutate(expr);
 }
 
 namespace transform {