[Relay][Training] Add checkpoint annotation for checkpointing memory optimization...
authorAltan Haan <altanh@cs.washington.edu>
Sun, 27 Oct 2019 00:04:42 +0000 (17:04 -0700)
committerJared Roesch <roeschinc@gmail.com>
Sun, 27 Oct 2019 00:04:42 +0000 (17:04 -0700)
* add checkpoint annotation for checkpointing memory optimization

* add alpha-equivalence checkpoint test and fix gradient type issue

* fix build issues

* ignore checkpoint annotation when checking missing gradients

* refactor, fix checkpoint compute for tuple and add tests

python/tvm/relay/op/annotation/annotation.py
src/relay/op/annotation/annotation.cc
src/relay/pass/de_duplicate.cc
src/relay/pass/gradient.cc
tests/python/relay/test_op_grad_level10.py
tests/python/relay/test_op_level10.py

index 10c8985..2b9d4bc 100644 (file)
 """Annotation operations."""
 from __future__ import absolute_import as _abs
 from . import _make
+from ..op import register_schedule, schedule_injective
 from .... import nd as _nd
 from .... import TVMContext as _TVMContext
 
-
 def on_device(data, device):
     """Annotate an expression with a certain device type.
 
@@ -61,3 +61,20 @@ def stop_fusion(data):
         The annotated expression.
     """
     return _make.stop_fusion(data)
+
+def checkpoint(data):
+    """Annotate an expression to be a checkpoint for the checkpointing memory optimization.
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The expression to be annotated.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The annotated expression.
+    """
+    return _make.checkpoint(data)
+
+register_schedule("annotation.checkpoint", schedule_injective)
index eeacc6c..5a8ad33 100644 (file)
@@ -144,5 +144,32 @@ Mark the end of bitpacking.
                          return {topi::identity(inputs[0])};
                        });
 
+TVM_REGISTER_API("relay.op.annotation._make.checkpoint")
+.set_body_typed<Expr(Expr)>([](Expr data) {
+  static const Op& op = Op::Get("annotation.checkpoint");
+  return CallNode::make(op, {data}, Attrs{}, {});
+});
+
+RELAY_REGISTER_OP("annotation.checkpoint")
+.describe(R"code(
+Mark a checkpoint for checkpointing memory optimization.
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.set_support_level(10)
+.add_type_rel("Identity", IdentityRel)
+.set_attr<TOpPattern>("TOpPattern", kOpaque)
+.set_attr<TOpIsStateful>("TOpIsStateful", false)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
+                               ElemwiseArbitraryLayout)
+.set_attr<FTVMCompute>("FTVMCompute",
+                       [](const Attrs& attrs, const Array<Tensor>& inputs,
+                          const Type& out_dtype, const Target& target) -> Array<Tensor> {
+                         Array<Tensor> outputs;
+                         for (size_t i = 0; i < inputs.size(); ++i) {
+                           outputs.push_back(topi::identity(inputs[i]));
+                         }
+                         return outputs;
+                       });
+
 }  // namespace relay
 }  // namespace tvm
index 332803c..38acdcd 100644 (file)
@@ -52,7 +52,9 @@ Expr DeDup(const Expr& e) {
     }
 
     Expr VisitExpr(const Expr& e) final {
-      return ExprMutator::VisitExpr(e);
+      auto ret = ExprMutator::VisitExpr(e);
+      ret->checked_type_ = e->checked_type_;
+      return ret;
     }
 
     Expr VisitExpr_(const VarNode* op) final {
index 8b06b87..b93c110 100644 (file)
@@ -273,24 +273,29 @@ Type ReverseType(const Type& t) {
  * by doing a structure preserving map.
  */
 Expr LiftTensor(const std::function<Expr(const Expr& t)>& f,
-                const Type& t,
+                const std::function<Type(const Type&)>& tf,
+                const Type& forward_type,
                 const Expr& e,
                 LetList* ll) {
   CHECK(IsAtomic(e)) << e;
-  if (t.as<TensorTypeNode>()) {
+  if (forward_type.as<TensorTypeNode>()) {
     auto ret = f(e);
-    ret->checked_type_ = t;
+    ret->checked_type_ = tf(forward_type);
     return ret;
-  } else if (auto* tt = t.as<TupleTypeNode>()) {
+  } else if (auto* tt = forward_type.as<TupleTypeNode>()) {
     tvm::Array<Expr> fields;
+    tvm::Array<Type> types;
     for (size_t i = 0; i < tt->fields.size(); ++i) {
-      fields.push_back(LiftTensor(f,
-                                  tt->fields[i],
-                                  ll->Push(GetField(e, i)),
-                                  ll));
+      auto field = LiftTensor(f,
+                              tf,
+                              tt->fields[i],
+                              ll->Push(GetField(e, i)),
+                              ll);
+      fields.push_back(field);
+      types.push_back(field->checked_type_);
     }
     auto ret = TupleNode::make(fields);
-    ret->checked_type_ = t;
+    ret->checked_type_ = TupleTypeNode::make(types);
     return std::move(ret);
   } else {
     LOG(FATAL) << "unsupported input/output type: " << tt;
@@ -298,25 +303,63 @@ Expr LiftTensor(const std::function<Expr(const Expr& t)>& f,
   }
 }
 
+/*! \brief Transfers the gradients from an Expr to a deep duplication of the Expr,
+ * by stitching the references in the AD values.
+ */
+void TransferGrads(const Type& forward_type,
+                   const Expr& from,
+                   const Expr& to,
+                   LetList* ll) {
+  CHECK(IsAtomic(from)) << from;
+  CHECK(IsAtomic(to)) << to;
+  if (forward_type.as<TensorTypeNode>()) {
+    auto from_ref = TupleGetItemNode::make(from, 1);
+    auto to_ref = TupleGetItemNode::make(to, 1);
+    ll->Push(RefWriteNode::make(to_ref, RefReadNode::make(from_ref)));
+  } else if (auto* tt = forward_type.as<TupleTypeNode>()) {
+    for (size_t i = 0; i < tt->fields.size(); ++i) {
+      TransferGrads(tt->fields[i],
+                    ll->Push(TupleGetItemNode::make(from, i)),
+                    ll->Push(TupleGetItemNode::make(to, i)),
+                    ll);
+    }
+  } else {
+    LOG(FATAL) << "Unsupported input/output type: " << forward_type;
+    throw;
+  }
+}
+
 /*! \brief t -> ReverseType(t). Transform to Reverse Mode Value. */
-Expr GetRev(const Type& t, const Expr& e, LetList* ll) {
+Expr GetRev(const Type& forward_type, const Expr& e, LetList* ll) {
   auto rev = [&](const Expr& e) {
     return Pair(e, ll->Push(RefCreateNode::make(ZerosLike(e))));
   };
-  return LiftTensor(rev, t, e, ll);
+  auto rev_type = [&](const Type& forward_type) {
+    return ReverseType(forward_type);
+  };
+  return LiftTensor(rev, rev_type, forward_type, e, ll);
 }
 
 /*! \brief ReverseType(t) -> t. Get the original value. */
-Expr GetValue(const Type& t, const Expr& e, LetList* ll) {
-  return LiftTensor([&](const Expr& e) { return GetField(e, 0); }, t, e, ll);
+Expr GetValue(const Type& forward_type, const Expr& e, LetList* ll) {
+  auto val = [&](const Expr& e) {
+    return GetField(e, 0);
+  };
+  auto val_type = [&](const Type& forward_type) {
+    return forward_type;
+  };
+  return LiftTensor(val, val_type, forward_type, e, ll);
 }
 
 /*! \brief ReverseType(t) -> t. Get the gradient. */
-Expr GetGrad(const Type& t, const Expr& e, LetList* ll) {
+Expr GetGrad(const Type& forward_type, const Expr& e, LetList* ll) {
   auto grad = [&](const Expr& e) {
     return ll->Push(RefReadNode::make(GetField(e, 1)));
   };
-  return LiftTensor(grad, t, e, ll);
+  auto grad_type = [&](const Type& forward_type) {
+    return forward_type;
+  };
+  return LiftTensor(grad, grad_type, forward_type, e, ll);
 }
 
 void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) {
@@ -337,42 +380,87 @@ void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) {
   }
 }
 
+Expr BPEmpty() {
+  Expr unitF = FunctionNode::make({}, TupleNode::make({}), TupleTypeNode::make({}), {});
+  return RefCreateNode::make(unitF);
+}
+
 struct ReverseAD : ExprMutator {
+  using ADVarMap = std::unordered_map<Var, Var, NodeHash, NodeEqual>;
+
   Var bp;
+  std::shared_ptr<ADVarMap> ad_vars;
   const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
 
-  explicit ReverseAD(const Var& bp) : bp(bp) { }
+  explicit ReverseAD(const Var& bp, std::shared_ptr<ADVarMap> ad_vars)
+      : bp(bp), ad_vars(ad_vars) { }
 
   Expr VisitExpr_(const OpNode* op) final {
     LOG(FATAL) << "op should only be inside call";
     throw;
   }
 
-  Expr VisitExpr_(const CallNode* op) final {
-    if (const OpNode* op_node = op->op.as<OpNode>()) {
+  Expr VisitCheckpoint(const CallNode *call) {
+    const OpNode* op_node = call->op.as<OpNode>();
+    CHECK(op_node) << "expected op in call";
+    Op op_ref = GetRef<Op>(op_node);
+    CHECK(op_ref->name == "annotation.checkpoint") << "expected checkpoint annotation";
+    auto x = call->args[0];
+    return LetList::With([&](LetList* ll) {
+      auto x_var = ll->Push(x);
+      auto ret = ll->Push(GetRev(call->checked_type(), x_var, ll));
+      auto bpv = ll->Push(RefReadNode::make(bp));
+      Expr nbp = FunctionNode::make(
+        {},
+        LetList::With([&](LetList* ll) {
+          // we need a new ReverseAD visitor to avoid clobbering the bp local var
+          auto dup_bp = ll->Push(BPEmpty());
+          ReverseAD dup_diff(dup_bp, ad_vars);
+          auto dup_ad = ll->Push(dup_diff.VisitExpr(DeDup(x)));
+
+          TransferGrads(call->checked_type(), ret, dup_ad, ll);
+          ll->Push(CallNode::make(RefReadNode::make(dup_bp), {}));
+          return CallNode::make(bpv, {});
+        }),
+        TupleTypeNode::make({}),
+        {});
+      ll->Push(RefWriteNode::make(bp, nbp));
+      return ret;
+    });
+  }
+
+  Expr VisitExpr_(const CallNode* call) final {
+    if (const OpNode* op_node = call->op.as<OpNode>()) {
       Op op_ref = GetRef<Op>(op_node);
+
+      if (op_ref->name == "annotation.checkpoint") {
+        return VisitCheckpoint(call);
+      }
+
+      CHECK(rev_map.count(op_ref))
+        << op_node->name << " does not have reverse mode defined";
       return LetList::With([&](LetList* ll) {
         std::vector<Var> args;
-        for (const auto& arg : op->args) {
+        for (const auto& arg : call->args) {
           args.push_back(ll->Push(VisitExpr(arg)));
         }
         std::vector<Expr> orig_args;
         for (size_t i = 0; i < args.size(); i++) {
-          orig_args.push_back(GetValue(op->args[i]->checked_type(), args[i], ll));
+          orig_args.push_back(GetValue(call->args[i]->checked_type(), args[i], ll));
         }
-        Expr orig = CallNode::make(op->op, orig_args, op->attrs, op->type_args);
-        orig->checked_type_ = op->checked_type();
+        Expr orig = CallNode::make(call->op, orig_args, call->attrs, call->type_args);
+        orig->checked_type_ = call->checked_type();
         Var orig_var = ll->Push(orig);
-        orig_var->checked_type_ = op->checked_type();
-        auto ret = ll->Push(GetRev(op->checked_type(), orig_var, ll));
+        orig_var->checked_type_ = call->checked_type();
+        auto ret = ll->Push(GetRev(call->checked_type(), orig_var, ll));
         auto bpv = ll->Push(RefReadNode::make(bp));
         Expr nbp = FunctionNode::make(
           {},
           LetList::With([&](LetList* ll) {
-            tvm::Array<Expr> rev = rev_map[op_ref](orig, GetGrad(op->checked_type(), ret, ll));
+            tvm::Array<Expr> rev = rev_map[op_ref](orig, GetGrad(call->checked_type(), ret, ll));
             CHECK(args.size() == rev.size());
             for (size_t i = 0; i < args.size(); ++i) {
-              UpdateGrad(op->args[i]->checked_type(), args[i], rev[i], ll);
+              UpdateGrad(call->args[i]->checked_type(), args[i], rev[i], ll);
             }
             return CallNode::make(bpv, {});
           }),
@@ -382,7 +470,7 @@ struct ReverseAD : ExprMutator {
         return ret;
       });
     }
-    return ExprMutator::VisitExpr_(op);
+    return ExprMutator::VisitExpr_(call);
   }
 
   Expr VisitExpr_(const ConstantNode* op) final {
@@ -396,16 +484,22 @@ struct ReverseAD : ExprMutator {
                         VisitExpr(op->false_branch));
   }
 
+  Expr VisitExpr_(const VarNode* var) final {
+    // memoize Var -> ADVar so we don't end up with free Vars when checkpointing
+    auto var_ref = GetRef<Var>(var);
+    if (!ad_vars->count(var_ref)) {
+      auto res = Downcast<Var>(ExprMutator::VisitExpr_(var));
+      (*ad_vars)[var_ref] = res;
+    }
+
+    return ad_vars->at(var_ref);
+  }
+
   Type VisitType(const Type& t) final {
     return t.defined() ? ReverseType(t) : t;
   }
 };
 
-Expr BPEmpty() {
-  Expr unitF = FunctionNode::make({}, TupleNode::make({}), TupleTypeNode::make({}), {});
-  return RefCreateNode::make(unitF);
-}
-
 bool MissingGrad(const Expr& e) {
   struct MGVisitor : ExprVisitor {
     const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
@@ -413,7 +507,7 @@ bool MissingGrad(const Expr& e) {
 
     void VisitExpr_(const OpNode* op) final {
       Op op_ref = GetRef<Op>(op);
-      if (!rev_map.count(op_ref)) {
+      if (op_ref->name != "annotation.checkpoint" && !rev_map.count(op_ref)) {
         op_names.insert(op_ref->name);
       }
       ExprVisitor::VisitExpr_(op);
@@ -445,7 +539,7 @@ Expr Gradient(const Expr& re, const Module& mod) {
   CHECK(!MissingGrad(e)) << "input has operators with missing gradients";
   Expr body = LetList::With([&](LetList* ll) {
     Var bp = ll->Push(BPEmpty());
-    Expr rev = ReverseAD(bp)(e);
+    Expr rev = ReverseAD(bp, std::make_shared<ReverseAD::ADVarMap>())(e);
     std::vector<Expr> args;
     for (const auto& p : f->params) {
       args.push_back(ll->Push(Pair(p, RefCreateNode::make(ZerosLike(p)))));
index 7aa9e0b..acf3b75 100644 (file)
@@ -30,6 +30,18 @@ def test_cross_entropy_with_logits_grad():
     x = relay.var("x", shape=(2, 5))
     y = relay.var("y", shape=(2, 5))
     check_grad(relay.Function([x, y], relay.op.nn.cross_entropy_with_logits(x, y)), eps=0.01, scale=0.1, mean=1)
+    
+def test_checkpoint():
+    inputs = [relay.var("x{}".format(i), shape=(1,)) for i in range(4)]
+    output = relay.multiply(relay.add(inputs[0], inputs[1]),
+                            relay.add(inputs[2], inputs[3]))
+    check_grad(relay.Function(inputs, relay.annotation.checkpoint(output)))
+
+    out_tuple = relay.Tuple([relay.add(inputs[0], inputs[1]),
+                             relay.multiply(inputs[2], inputs[3])])
+    out_single = relay.subtract(relay.TupleGetItem(relay.annotation.checkpoint(out_tuple), 0),
+                                relay.TupleGetItem(out_tuple, 1))
+    check_grad(relay.Function(inputs, out_single))
 
 
 if __name__ == "__main__":
index e828fa3..d9e29d8 100644 (file)
@@ -31,6 +31,127 @@ def run_infer_type(expr):
     entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
 
+def test_checkpoint():
+    dtype = "float32"
+    xs = [relay.var("x{}".format(i), dtype) for i in range(4)]
+    f = relay.multiply(relay.add(xs[0], xs[1]), relay.add(xs[2], xs[3]))
+    f_checkpoint = relay.annotation.checkpoint(f)
+
+    func, func_checkpoint = relay.Function(xs, f), relay.Function(xs, f_checkpoint)
+    f, f_checkpoint = run_infer_type(func), run_infer_type(func_checkpoint)
+    assert f.checked_type == f_checkpoint.checked_type
+
+    inputs = [np.random.uniform() for _ in range(len(xs))]
+    for target, ctx in ctx_list():
+        for kind in ["graph", "debug"]:
+            intrp = relay.create_executor(kind, ctx=ctx, target=target)
+            f_res = intrp.evaluate(f)(*inputs)
+            f_checkpoint_res = intrp.evaluate(f_checkpoint)(*inputs)
+            tvm.testing.assert_allclose(f_res.asnumpy(), f_checkpoint_res.asnumpy(), 0, 0)
+
+def test_checkpoint_alpha_equal():
+    xs = [relay.var("x{}".format(i), relay.TensorType((1,), "float32")) for i in range(4)]
+    f = relay.Function(xs, relay.annotation.checkpoint(
+        relay.multiply(relay.add(xs[0], xs[1]), relay.add(xs[2], xs[3]))
+    ))
+    df = transform.gradient(run_infer_type(f))
+
+    # run PE and DCE
+    with transform.PassContext(opt_level=3):
+        passes = [transform.PartialEvaluate(),
+                  transform.DeadCodeElimination(inline_once=True)]
+        mod = transform.Sequential(passes)(relay.Module.from_expr(df))
+        df = mod["main"]
+
+    df_parsed = relay.parser.fromtext(
+        """
+        v0.0.4
+        fn (%x: Tensor[(1), float32], %y: Tensor[(1), float32],
+            %z: Tensor[(1), float32], %w: Tensor[(1), float32])
+            ->  (Tensor[(1), float32],
+                (Tensor[(1), float32], Tensor[(1), float32],
+                 Tensor[(1), float32], Tensor[(1), float32])) {
+            %0 = add(%x, %y);
+            %1 = add(%z, %w);
+            let %x1: Tensor[(1), float32] = multiply(%0, %1);
+            let %x2: Tensor[(1), float32] = ones_like(%x1);
+            let %x3: Tensor[(1), float32] = add(%x, %y);
+            let %x4: Tensor[(1), float32] = add(%z, %w);
+            %2 = zeros_like(%x3);
+            %3 = multiply(%x2, %x4);
+            %4 = collapse_sum_like(%3, %x3);
+            let %x5: Tensor[(1), float32] = add(%2, %4);
+            %5 = zeros_like(%x4);
+            %6 = multiply(%x2, %x3);
+            %7 = collapse_sum_like(%6, %x4);
+            let %x6: Tensor[(1), float32] = add(%5, %7);
+            %8 = zeros_like(%x);
+            %9 = collapse_sum_like(%x5, %x);
+            %10 = add(%8, %9);
+            %11 = zeros_like(%y);
+            %12 = collapse_sum_like(%x5, %y);
+            %13 = add(%11, %12);
+            %14 = zeros_like(%z);
+            %15 = collapse_sum_like(%x6, %z);
+            %16 = add(%14, %15);
+            %17 = zeros_like(%w);
+            %18 = collapse_sum_like(%x6, %w);
+            %19 = add(%17, %18);
+            %20 = (%10, %13, %16, %19);
+            (%x1, %20)
+        }
+        """
+    )
+
+    relay.analysis.assert_alpha_equal(df, df_parsed)
+
+def test_checkpoint_alpha_equal_tuple():
+    xs = [relay.var("x{}".format(i), relay.TensorType((1,), "float32")) for i in range(4)]
+    f = relay.Function(xs, relay.annotation.checkpoint(
+        relay.Tuple([relay.add(xs[0], xs[1]), relay.add(xs[2], xs[3])])
+    ))
+    df = transform.gradient(run_infer_type(f))
+
+    # run PE and DCE
+    with transform.PassContext(opt_level=3):
+        passes = [transform.PartialEvaluate(),
+                  transform.DeadCodeElimination(inline_once=True)]
+        mod = transform.Sequential(passes)(relay.Module.from_expr(df))
+        df = mod["main"]
+
+    df_parsed = relay.parser.fromtext(
+        """
+        v0.0.4
+        fn (%x: Tensor[(1), float32], %y: Tensor[(1), float32],
+            %z: Tensor[(1), float32], %w: Tensor[(1), float32])
+            -> ((Tensor[(1), float32], Tensor[(1), float32]),
+                (Tensor[(1), float32], Tensor[(1), float32],
+                 Tensor[(1), float32], Tensor[(1), float32])) {
+        let %x1: Tensor[(1), float32] = add(%x, %y) /* ty=Tensor[(1), float32] */;
+        let %x2: Tensor[(1), float32] = add(%z, %w) /* ty=Tensor[(1), float32] */;
+        let %x3: Tensor[(1), float32] = zeros_like(%x2) /* ty=Tensor[(1), float32] */;
+        let %x4: Tensor[(1), float32] = ones_like(%x1) /* ty=Tensor[(1), float32] */;
+        %0 = (%x1, %x2);
+        %1 = zeros_like(%x) /* ty=Tensor[(1), float32] */;
+        %2 = collapse_sum_like(%x4, %x) /* ty=Tensor[(1), float32] */;
+        %3 = add(%1, %2) /* ty=Tensor[(1), float32] */;
+        %4 = zeros_like(%y) /* ty=Tensor[(1), float32] */;
+        %5 = collapse_sum_like(%x4, %y) /* ty=Tensor[(1), float32] */;
+        %6 = add(%4, %5) /* ty=Tensor[(1), float32] */;
+        %7 = zeros_like(%z) /* ty=Tensor[(1), float32] */;
+        %8 = collapse_sum_like(%x3, %z) /* ty=Tensor[(1), float32] */;
+        %9 = add(%7, %8) /* ty=Tensor[(1), float32] */;
+        %10 = zeros_like(%w) /* ty=Tensor[(1), float32] */;
+        %11 = collapse_sum_like(%x3, %w) /* ty=Tensor[(1), float32] */;
+        %12 = add(%10, %11) /* ty=Tensor[(1), float32] */;
+        %13 = (%3, %6, %9, %12);
+        (%0, %13)
+        }
+        """
+    )
+
+    relay.analysis.assert_alpha_equal(df, df_parsed)
+
 def test_collapse_sum_like():
     shape = (3, 4, 5, 6)
     shape_like = (4, 5, 6)