 * by doing a structure-preserving map.
*/
Expr LiftTensor(const std::function<Expr(const Expr& t)>& f,
- const Type& t,
+ const std::function<Type(const Type&)>& tf,
+ const Type& forward_type,
const Expr& e,
LetList* ll) {
CHECK(IsAtomic(e)) << e;
- if (t.as<TensorTypeNode>()) {
+ if (forward_type.as<TensorTypeNode>()) {
auto ret = f(e);
- ret->checked_type_ = t;
+ ret->checked_type_ = tf(forward_type);
return ret;
- } else if (auto* tt = t.as<TupleTypeNode>()) {
+ } else if (auto* tt = forward_type.as<TupleTypeNode>()) {
tvm::Array<Expr> fields;
+ tvm::Array<Type> types;
for (size_t i = 0; i < tt->fields.size(); ++i) {
- fields.push_back(LiftTensor(f,
- tt->fields[i],
- ll->Push(GetField(e, i)),
- ll));
+ auto field = LiftTensor(f,
+ tf,
+ tt->fields[i],
+ ll->Push(GetField(e, i)),
+ ll);
+ fields.push_back(field);
+ types.push_back(field->checked_type_);
}
auto ret = TupleNode::make(fields);
- ret->checked_type_ = t;
+ ret->checked_type_ = TupleTypeNode::make(types);
return std::move(ret);
} else {
LOG(FATAL) << "unsupported input/output type: " << tt;
}
}
+/*! \brief Transfers the gradients from an Expr to a deep duplication of the Expr,
+ * by stitching the references in the AD values.
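+ * An AD value pairs each forward tensor with a mutable gradient cell, e.g. a
+ * tensor t is represented as (t, ref(zeros_like(t))); this copies the current
+ * gradient of `from` into the gradient reference of `to`, field by field.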
+ */
+void TransferGrads(const Type& forward_type,
+ const Expr& from,
+ const Expr& to,
+ LetList* ll) {
+ CHECK(IsAtomic(from)) << from;
+ CHECK(IsAtomic(to)) << to;
+ if (forward_type.as<TensorTypeNode>()) {
+ auto from_ref = TupleGetItemNode::make(from, 1);
+ auto to_ref = TupleGetItemNode::make(to, 1);
+ ll->Push(RefWriteNode::make(to_ref, RefReadNode::make(from_ref)));
+ } else if (auto* tt = forward_type.as<TupleTypeNode>()) {
+ for (size_t i = 0; i < tt->fields.size(); ++i) {
+ TransferGrads(tt->fields[i],
+ ll->Push(TupleGetItemNode::make(from, i)),
+ ll->Push(TupleGetItemNode::make(to, i)),
+ ll);
+ }
+ } else {
+ LOG(FATAL) << "Unsupported input/output type: " << forward_type;
+ throw;
+ }
+}
+
/*! \brief t -> ReverseType(t). Transform to Reverse Mode Value. */
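+// For example, a tensor value x becomes Pair(x, ref(zeros_like(x))).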
-Expr GetRev(const Type& t, const Expr& e, LetList* ll) {
+Expr GetRev(const Type& forward_type, const Expr& e, LetList* ll) {
auto rev = [&](const Expr& e) {
return Pair(e, ll->Push(RefCreateNode::make(ZerosLike(e))));
};
- return LiftTensor(rev, t, e, ll);
+ auto rev_type = [&](const Type& forward_type) {
+ return ReverseType(forward_type);
+ };
+ return LiftTensor(rev, rev_type, forward_type, e, ll);
}
/*! \brief ReverseType(t) -> t. Get the original value. */
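+// For example, an AD value (x, ref(g)) projects back to its forward value x.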
-Expr GetValue(const Type& t, const Expr& e, LetList* ll) {
- return LiftTensor([&](const Expr& e) { return GetField(e, 0); }, t, e, ll);
+Expr GetValue(const Type& forward_type, const Expr& e, LetList* ll) {
+ auto val = [&](const Expr& e) {
+ return GetField(e, 0);
+ };
+ auto val_type = [&](const Type& forward_type) {
+ return forward_type;
+ };
+ return LiftTensor(val, val_type, forward_type, e, ll);
}
/*! \brief ReverseType(t) -> t. Get the gradient. */
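+// For example, an AD value (x, ref(g)) yields a read of the gradient cell g.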
-Expr GetGrad(const Type& t, const Expr& e, LetList* ll) {
+Expr GetGrad(const Type& forward_type, const Expr& e, LetList* ll) {
auto grad = [&](const Expr& e) {
return ll->Push(RefReadNode::make(GetField(e, 1)));
};
- return LiftTensor(grad, t, e, ll);
+ auto grad_type = [&](const Type& forward_type) {
+ return forward_type;
+ };
+ return LiftTensor(grad, grad_type, forward_type, e, ll);
}
+/*! \brief Accumulate a gradient into the gradient cell of an AD value. */
void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) {
  if (t.as<TensorTypeNode>()) {
    ll->Push(RefWriteNode::make(GetField(arg, 1),
                                Add(ll->Push(RefReadNode::make(GetField(arg, 1))),
                                    grad)));
  } else if (auto* tt = t.as<TupleTypeNode>()) {
    for (size_t i = 0; i < tt->fields.size(); ++i) {
      UpdateGrad(tt->fields[i],
                 ll->Push(GetField(arg, i)),
                 ll->Push(GetField(grad, i)),
                 ll);
    }
  } else {
    LOG(FATAL) << "unsupported arg type of operator: " << t;
    throw;
  }
}
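+/*! \brief The empty backpropagator: a reference cell holding a closure that
+ * does nothing. ReverseAD threads this cell through the program and composes
+ * each operator's gradient update onto it.
+ */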
+Expr BPEmpty() {
+ Expr unitF = FunctionNode::make({}, TupleNode::make({}), TupleTypeNode::make({}), {});
+ return RefCreateNode::make(unitF);
+}
+
struct ReverseAD : ExprMutator {
+ using ADVarMap = std::unordered_map<Var, Var, NodeHash, NodeEqual>;
+
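+  // bp is a Ref holding the current backpropagator closure; ad_vars memoizes
+  // Var -> AD Var so that duplicates created while checkpointing reuse the
+  // same bindings instead of introducing free variables.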
Var bp;
+ std::shared_ptr<ADVarMap> ad_vars;
const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
- explicit ReverseAD(const Var& bp) : bp(bp) { }
+ explicit ReverseAD(const Var& bp, std::shared_ptr<ADVarMap> ad_vars)
+ : bp(bp), ad_vars(ad_vars) { }
Expr VisitExpr_(const OpNode* op) final {
LOG(FATAL) << "op should only be inside call";
throw;
}
- Expr VisitExpr_(const CallNode* op) final {
- if (const OpNode* op_node = op->op.as<OpNode>()) {
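+  // Handle a call to annotation.checkpoint: emit the forward computation as
+  // usual, but make the backpropagator re-run reverse AD over a duplicate of
+  // the annotated expression, so intermediate gradients are recomputed during
+  // the backward pass rather than stored.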
+  Expr VisitCheckpoint(const CallNode* call) {
+ const OpNode* op_node = call->op.as<OpNode>();
+ CHECK(op_node) << "expected op in call";
+ Op op_ref = GetRef<Op>(op_node);
+ CHECK(op_ref->name == "annotation.checkpoint") << "expected checkpoint annotation";
+ auto x = call->args[0];
+ return LetList::With([&](LetList* ll) {
+ auto x_var = ll->Push(x);
+ auto ret = ll->Push(GetRev(call->checked_type(), x_var, ll));
+ auto bpv = ll->Push(RefReadNode::make(bp));
+ Expr nbp = FunctionNode::make(
+ {},
+ LetList::With([&](LetList* ll) {
+ // we need a new ReverseAD visitor to avoid clobbering the bp local var
+ auto dup_bp = ll->Push(BPEmpty());
+ ReverseAD dup_diff(dup_bp, ad_vars);
+ auto dup_ad = ll->Push(dup_diff.VisitExpr(DeDup(x)));
+
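+          // move the gradients accumulated on ret onto the duplicated AD
+          // values, then run the duplicate's backpropagator before the old one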
+ TransferGrads(call->checked_type(), ret, dup_ad, ll);
+ ll->Push(CallNode::make(RefReadNode::make(dup_bp), {}));
+ return CallNode::make(bpv, {});
+ }),
+ TupleTypeNode::make({}),
+ {});
+ ll->Push(RefWriteNode::make(bp, nbp));
+ return ret;
+ });
+ }
+
+ Expr VisitExpr_(const CallNode* call) final {
+ if (const OpNode* op_node = call->op.as<OpNode>()) {
Op op_ref = GetRef<Op>(op_node);
+
+ if (op_ref->name == "annotation.checkpoint") {
+ return VisitCheckpoint(call);
+ }
+
+ CHECK(rev_map.count(op_ref))
+ << op_node->name << " does not have reverse mode defined";
return LetList::With([&](LetList* ll) {
std::vector<Var> args;
- for (const auto& arg : op->args) {
+ for (const auto& arg : call->args) {
args.push_back(ll->Push(VisitExpr(arg)));
}
std::vector<Expr> orig_args;
for (size_t i = 0; i < args.size(); i++) {
- orig_args.push_back(GetValue(op->args[i]->checked_type(), args[i], ll));
+ orig_args.push_back(GetValue(call->args[i]->checked_type(), args[i], ll));
}
- Expr orig = CallNode::make(op->op, orig_args, op->attrs, op->type_args);
- orig->checked_type_ = op->checked_type();
+ Expr orig = CallNode::make(call->op, orig_args, call->attrs, call->type_args);
+ orig->checked_type_ = call->checked_type();
Var orig_var = ll->Push(orig);
- orig_var->checked_type_ = op->checked_type();
- auto ret = ll->Push(GetRev(op->checked_type(), orig_var, ll));
+ orig_var->checked_type_ = call->checked_type();
+ auto ret = ll->Push(GetRev(call->checked_type(), orig_var, ll));
auto bpv = ll->Push(RefReadNode::make(bp));
Expr nbp = FunctionNode::make(
{},
LetList::With([&](LetList* ll) {
- tvm::Array<Expr> rev = rev_map[op_ref](orig, GetGrad(op->checked_type(), ret, ll));
+ tvm::Array<Expr> rev = rev_map[op_ref](orig, GetGrad(call->checked_type(), ret, ll));
CHECK(args.size() == rev.size());
for (size_t i = 0; i < args.size(); ++i) {
- UpdateGrad(op->args[i]->checked_type(), args[i], rev[i], ll);
+ UpdateGrad(call->args[i]->checked_type(), args[i], rev[i], ll);
}
return CallNode::make(bpv, {});
        }),
        TupleTypeNode::make({}),
        {});
      ll->Push(RefWriteNode::make(bp, nbp));
      return ret;
});
}
- return ExprMutator::VisitExpr_(op);
+ return ExprMutator::VisitExpr_(call);
}
  Expr VisitExpr_(const ConstantNode* op) final {
    Expr e = GetRef<Expr>(op);
    return Pair(e, RefCreateNode::make(ZerosLike(e)));
  }
  Expr VisitExpr_(const IfNode* op) final {
    return IfNode::make(TupleGetItemNode::make(VisitExpr(op->cond), 0),
                        VisitExpr(op->true_branch),
                        VisitExpr(op->false_branch));
  }
+ Expr VisitExpr_(const VarNode* var) final {
+ // memoize Var -> ADVar so we don't end up with free Vars when checkpointing
+ auto var_ref = GetRef<Var>(var);
+ if (!ad_vars->count(var_ref)) {
+ auto res = Downcast<Var>(ExprMutator::VisitExpr_(var));
+ (*ad_vars)[var_ref] = res;
+ }
+
+ return ad_vars->at(var_ref);
+ }
+
Type VisitType(const Type& t) final {
return t.defined() ? ReverseType(t) : t;
}
};
-Expr BPEmpty() {
- Expr unitF = FunctionNode::make({}, TupleNode::make({}), TupleTypeNode::make({}), {});
- return RefCreateNode::make(unitF);
-}
-
bool MissingGrad(const Expr& e) {
struct MGVisitor : ExprVisitor {
const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
    std::unordered_set<std::string> op_names;

    void VisitExpr_(const OpNode* op) final {
Op op_ref = GetRef<Op>(op);
- if (!rev_map.count(op_ref)) {
+ if (op_ref->name != "annotation.checkpoint" && !rev_map.count(op_ref)) {
op_names.insert(op_ref->name);
}
ExprVisitor::VisitExpr_(op);
CHECK(!MissingGrad(e)) << "input has operators with missing gradients";
Expr body = LetList::With([&](LetList* ll) {
Var bp = ll->Push(BPEmpty());
- Expr rev = ReverseAD(bp)(e);
+ Expr rev = ReverseAD(bp, std::make_shared<ReverseAD::ADVarMap>())(e);
std::vector<Expr> args;
for (const auto& p : f->params) {
args.push_back(ll->Push(Pair(p, RefCreateNode::make(ZerosLike(p)))));
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
+
+def test_checkpoint():
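+    # checkpointing must be semantically transparent: the annotated function
+    # should have the same type and evaluate to the same values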
+ dtype = "float32"
+ xs = [relay.var("x{}".format(i), dtype) for i in range(4)]
+ f = relay.multiply(relay.add(xs[0], xs[1]), relay.add(xs[2], xs[3]))
+ f_checkpoint = relay.annotation.checkpoint(f)
+
+ func, func_checkpoint = relay.Function(xs, f), relay.Function(xs, f_checkpoint)
+ f, f_checkpoint = run_infer_type(func), run_infer_type(func_checkpoint)
+ assert f.checked_type == f_checkpoint.checked_type
+
+ inputs = [np.random.uniform() for _ in range(len(xs))]
+ for target, ctx in ctx_list():
+ for kind in ["graph", "debug"]:
+ intrp = relay.create_executor(kind, ctx=ctx, target=target)
+ f_res = intrp.evaluate(f)(*inputs)
+ f_checkpoint_res = intrp.evaluate(f_checkpoint)(*inputs)
+ tvm.testing.assert_allclose(f_res.asnumpy(), f_checkpoint_res.asnumpy(), 0, 0)
+
+def test_checkpoint_alpha_equal():
+ xs = [relay.var("x{}".format(i), relay.TensorType((1,), "float32")) for i in range(4)]
+ f = relay.Function(xs, relay.annotation.checkpoint(
+ relay.multiply(relay.add(xs[0], xs[1]), relay.add(xs[2], xs[3]))
+ ))
+ df = transform.gradient(run_infer_type(f))
+
+ # run PE and DCE
+ with transform.PassContext(opt_level=3):
+ passes = [transform.PartialEvaluate(),
+ transform.DeadCodeElimination(inline_once=True)]
+ mod = transform.Sequential(passes)(relay.Module.from_expr(df))
+ df = mod["main"]
+
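+    # expected gradient after PE and DCE: note that %x3 and %x4 recompute the
+    # checkpointed adds in the backward pass instead of reusing %0 and %1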
+ df_parsed = relay.parser.fromtext(
+ """
+ v0.0.4
+ fn (%x: Tensor[(1), float32], %y: Tensor[(1), float32],
+ %z: Tensor[(1), float32], %w: Tensor[(1), float32])
+ -> (Tensor[(1), float32],
+ (Tensor[(1), float32], Tensor[(1), float32],
+ Tensor[(1), float32], Tensor[(1), float32])) {
+ %0 = add(%x, %y);
+ %1 = add(%z, %w);
+ let %x1: Tensor[(1), float32] = multiply(%0, %1);
+ let %x2: Tensor[(1), float32] = ones_like(%x1);
+ let %x3: Tensor[(1), float32] = add(%x, %y);
+ let %x4: Tensor[(1), float32] = add(%z, %w);
+ %2 = zeros_like(%x3);
+ %3 = multiply(%x2, %x4);
+ %4 = collapse_sum_like(%3, %x3);
+ let %x5: Tensor[(1), float32] = add(%2, %4);
+ %5 = zeros_like(%x4);
+ %6 = multiply(%x2, %x3);
+ %7 = collapse_sum_like(%6, %x4);
+ let %x6: Tensor[(1), float32] = add(%5, %7);
+ %8 = zeros_like(%x);
+ %9 = collapse_sum_like(%x5, %x);
+ %10 = add(%8, %9);
+ %11 = zeros_like(%y);
+ %12 = collapse_sum_like(%x5, %y);
+ %13 = add(%11, %12);
+ %14 = zeros_like(%z);
+ %15 = collapse_sum_like(%x6, %z);
+ %16 = add(%14, %15);
+ %17 = zeros_like(%w);
+ %18 = collapse_sum_like(%x6, %w);
+ %19 = add(%17, %18);
+ %20 = (%10, %13, %16, %19);
+ (%x1, %20)
+ }
+ """
+ )
+
+ relay.analysis.assert_alpha_equal(df, df_parsed)
+
+def test_checkpoint_alpha_equal_tuple():
+ xs = [relay.var("x{}".format(i), relay.TensorType((1,), "float32")) for i in range(4)]
+ f = relay.Function(xs, relay.annotation.checkpoint(
+ relay.Tuple([relay.add(xs[0], xs[1]), relay.add(xs[2], xs[3])])
+ ))
+ df = transform.gradient(run_infer_type(f))
+
+ # run PE and DCE
+ with transform.PassContext(opt_level=3):
+ passes = [transform.PartialEvaluate(),
+ transform.DeadCodeElimination(inline_once=True)]
+ mod = transform.Sequential(passes)(relay.Module.from_expr(df))
+ df = mod["main"]
+
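+    # expected gradient after PE and DCE for the checkpointed tuple result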
+ df_parsed = relay.parser.fromtext(
+ """
+ v0.0.4
+ fn (%x: Tensor[(1), float32], %y: Tensor[(1), float32],
+ %z: Tensor[(1), float32], %w: Tensor[(1), float32])
+ -> ((Tensor[(1), float32], Tensor[(1), float32]),
+ (Tensor[(1), float32], Tensor[(1), float32],
+ Tensor[(1), float32], Tensor[(1), float32])) {
+ let %x1: Tensor[(1), float32] = add(%x, %y) /* ty=Tensor[(1), float32] */;
+ let %x2: Tensor[(1), float32] = add(%z, %w) /* ty=Tensor[(1), float32] */;
+ let %x3: Tensor[(1), float32] = zeros_like(%x2) /* ty=Tensor[(1), float32] */;
+ let %x4: Tensor[(1), float32] = ones_like(%x1) /* ty=Tensor[(1), float32] */;
+ %0 = (%x1, %x2);
+ %1 = zeros_like(%x) /* ty=Tensor[(1), float32] */;
+ %2 = collapse_sum_like(%x4, %x) /* ty=Tensor[(1), float32] */;
+ %3 = add(%1, %2) /* ty=Tensor[(1), float32] */;
+ %4 = zeros_like(%y) /* ty=Tensor[(1), float32] */;
+ %5 = collapse_sum_like(%x4, %y) /* ty=Tensor[(1), float32] */;
+ %6 = add(%4, %5) /* ty=Tensor[(1), float32] */;
+ %7 = zeros_like(%z) /* ty=Tensor[(1), float32] */;
+ %8 = collapse_sum_like(%x3, %z) /* ty=Tensor[(1), float32] */;
+ %9 = add(%7, %8) /* ty=Tensor[(1), float32] */;
+ %10 = zeros_like(%w) /* ty=Tensor[(1), float32] */;
+ %11 = collapse_sum_like(%x3, %w) /* ty=Tensor[(1), float32] */;
+ %12 = add(%10, %11) /* ty=Tensor[(1), float32] */;
+ %13 = (%3, %6, %9, %12);
+ (%0, %13)
+ }
+ """
+ )
+
+ relay.analysis.assert_alpha_equal(df, df_parsed)
+
def test_collapse_sum_like():
shape = (3, 4, 5, 6)
shape_like = (4, 5, 6)